diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 1bb7d06232..94e857f93a 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -36,7 +36,6 @@ /api/core/workflow/graph/ @laipz8200 @QuantumGhost /api/core/workflow/graph_events/ @laipz8200 @QuantumGhost /api/core/workflow/node_events/ @laipz8200 @QuantumGhost -/api/dify_graph/model_runtime/ @laipz8200 @QuantumGhost # Backend - Workflow - Nodes (Agent, Iteration, Loop, LLM) /api/core/workflow/nodes/agent/ @Nov1c444 diff --git a/.github/workflows/api-tests.yml b/.github/workflows/api-tests.yml index 6b87946221..7bce056970 100644 --- a/.github/workflows/api-tests.yml +++ b/.github/workflows/api-tests.yml @@ -14,18 +14,17 @@ concurrency: cancel-in-progress: true jobs: - test: - name: API Tests + api-unit: + name: API Unit Tests runs-on: ubuntu-latest env: - CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} + COVERAGE_FILE: coverage-unit defaults: run: shell: bash strategy: matrix: python-version: - - "3.11" - "3.12" steps: @@ -51,6 +50,52 @@ jobs: - name: Run dify config tests run: uv run --project api dev/pytest/pytest_config_tests.py + - name: Run Unit Tests + run: uv run --project api bash dev/pytest/pytest_unit_tests.sh + + - name: Upload unit coverage data + uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0 + with: + name: api-coverage-unit + path: coverage-unit + retention-days: 1 + + api-integration: + name: API Integration Tests + runs-on: ubuntu-latest + env: + COVERAGE_FILE: coverage-integration + STORAGE_TYPE: opendal + OPENDAL_SCHEME: fs + OPENDAL_FS_ROOT: /tmp/dify-storage + defaults: + run: + shell: bash + strategy: + matrix: + python-version: + - "3.12" + + steps: + - name: Checkout code + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + with: + fetch-depth: 0 + persist-credentials: false + + - name: Setup UV and Python + uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0 + with: + enable-cache: true + python-version: 
${{ matrix.python-version }} + cache-dependency-glob: api/uv.lock + + - name: Check UV lockfile + run: uv lock --project api --check + + - name: Install dependencies + run: uv sync --project api --dev + - name: Set up dotenvs run: | cp docker/.env.example docker/.env @@ -74,22 +119,90 @@ jobs: run: | cp api/tests/integration_tests/.env.example api/tests/integration_tests/.env - - name: Run API Tests - env: - STORAGE_TYPE: opendal - OPENDAL_SCHEME: fs - OPENDAL_FS_ROOT: /tmp/dify-storage + - name: Run Integration Tests run: | uv run --project api pytest \ -n auto \ --timeout "${PYTEST_TIMEOUT:-180}" \ api/tests/integration_tests/workflow \ api/tests/integration_tests/tools \ - api/tests/test_containers_integration_tests \ - api/tests/unit_tests + api/tests/test_containers_integration_tests + + - name: Upload integration coverage data + uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0 + with: + name: api-coverage-integration + path: coverage-integration + retention-days: 1 + + api-coverage: + name: API Coverage + runs-on: ubuntu-latest + needs: + - api-unit + - api-integration + env: + CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} + COVERAGE_FILE: .coverage + defaults: + run: + shell: bash + + steps: + - name: Checkout code + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + with: + fetch-depth: 0 + persist-credentials: false + + - name: Setup UV and Python + uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0 + with: + enable-cache: true + python-version: "3.12" + cache-dependency-glob: api/uv.lock + + - name: Install dependencies + run: uv sync --project api --dev + + - name: Download coverage data + uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8.0.1 + with: + path: coverage-data + pattern: api-coverage-* + merge-multiple: true + + - name: Combine coverage + run: | + set -euo pipefail + + echo "### API Coverage" >> "$GITHUB_STEP_SUMMARY" + echo "" >> 
"$GITHUB_STEP_SUMMARY" + echo "Merged backend coverage report generated for Codecov project status." >> "$GITHUB_STEP_SUMMARY" + echo "" >> "$GITHUB_STEP_SUMMARY" + + unit_coverage="$(find coverage-data -type f -name coverage-unit -print -quit)" + integration_coverage="$(find coverage-data -type f -name coverage-integration -print -quit)" + : "${unit_coverage:?coverage-unit artifact not found}" + : "${integration_coverage:?coverage-integration artifact not found}" + + report_file="$(mktemp)" + uv run --project api coverage combine "$unit_coverage" "$integration_coverage" + uv run --project api coverage report --show-missing | tee "$report_file" + echo "Summary: \`$(tail -n 1 "$report_file")\`" >> "$GITHUB_STEP_SUMMARY" + { + echo "" + echo "
<details><summary>Coverage report</summary>" + echo "" + echo '```' + cat "$report_file" + echo '```' + echo "</details>
" + } >> "$GITHUB_STEP_SUMMARY" + uv run --project api coverage xml -o coverage.xml - name: Report coverage - if: ${{ env.CODECOV_TOKEN != '' && matrix.python-version == '3.12' }} + if: ${{ env.CODECOV_TOKEN != '' }} uses: codecov/codecov-action@1af58845a975a7985b0beb0cbe6fbbb71a41dbad # v5.5.3 with: files: ./coverage.xml diff --git a/.github/workflows/autofix.yml b/.github/workflows/autofix.yml index be6186980e..d8a53c9594 100644 --- a/.github/workflows/autofix.yml +++ b/.github/workflows/autofix.yml @@ -2,6 +2,9 @@ name: autofix.ci on: pull_request: branches: ["main"] + merge_group: + branches: ["main"] + types: [checks_requested] push: branches: ["main"] permissions: @@ -12,9 +15,15 @@ jobs: if: github.repository == 'langgenius/dify' runs-on: ubuntu-latest steps: - - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + - name: Complete merge group check + if: github.event_name == 'merge_group' + run: echo "autofix.ci updates pull request branches, not merge group refs." 
+ + - if: github.event_name != 'merge_group' + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - name: Check Docker Compose inputs + if: github.event_name != 'merge_group' id: docker-compose-changes uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5 with: @@ -24,30 +33,34 @@ jobs: docker/docker-compose-template.yaml docker/docker-compose.yaml - name: Check web inputs + if: github.event_name != 'merge_group' id: web-changes uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5 with: files: | web/** - name: Check api inputs + if: github.event_name != 'merge_group' id: api-changes uses: tj-actions/changed-files@22103cc46bda19c2b464ffe86db46df6922fd323 # v47.0.5 with: files: | api/** - - uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0 + - if: github.event_name != 'merge_group' + uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0 with: python-version: "3.11" - - uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0 + - if: github.event_name != 'merge_group' + uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0 - name: Generate Docker Compose - if: steps.docker-compose-changes.outputs.any_changed == 'true' + if: github.event_name != 'merge_group' && steps.docker-compose-changes.outputs.any_changed == 'true' run: | cd docker ./generate_docker_compose - - if: steps.api-changes.outputs.any_changed == 'true' + - if: github.event_name != 'merge_group' && steps.api-changes.outputs.any_changed == 'true' run: | cd api uv sync --dev @@ -59,13 +72,13 @@ jobs: uv run ruff format .. 
- name: count migration progress - if: steps.api-changes.outputs.any_changed == 'true' + if: github.event_name != 'merge_group' && steps.api-changes.outputs.any_changed == 'true' run: | cd api ./cnt_base.sh - name: ast-grep - if: steps.api-changes.outputs.any_changed == 'true' + if: github.event_name != 'merge_group' && steps.api-changes.outputs.any_changed == 'true' run: | # ast-grep exits 1 if no matches are found; allow idempotent runs. uvx --from ast-grep-cli ast-grep --pattern 'db.session.query($WHATEVER).filter($HERE)' --rewrite 'db.session.query($WHATEVER).where($HERE)' -l py --update-all || true @@ -95,13 +108,14 @@ jobs: find . -name "*.py.bak" -type f -delete - name: Setup web environment - if: steps.web-changes.outputs.any_changed == 'true' + if: github.event_name != 'merge_group' && steps.web-changes.outputs.any_changed == 'true' uses: ./.github/actions/setup-web - name: ESLint autofix - if: steps.web-changes.outputs.any_changed == 'true' + if: github.event_name != 'merge_group' && steps.web-changes.outputs.any_changed == 'true' run: | cd web vp exec eslint --concurrency=2 --prune-suppressions --quiet || true - - uses: autofix-ci/action@7a166d7532b277f34e16238930461bf77f9d7ed8 # v1.3.3 + - if: github.event_name != 'merge_group' + uses: autofix-ci/action@7a166d7532b277f34e16238930461bf77f9d7ed8 # v1.3.3 diff --git a/.github/workflows/main-ci.yml b/.github/workflows/main-ci.yml index 69023c24cc..2d96dae4da 100644 --- a/.github/workflows/main-ci.yml +++ b/.github/workflows/main-ci.yml @@ -3,10 +3,14 @@ name: Main CI Pipeline on: pull_request: branches: ["main"] + merge_group: + branches: ["main"] + types: [checks_requested] push: branches: ["main"] permissions: + actions: write contents: write pull-requests: write checks: write @@ -17,12 +21,28 @@ concurrency: cancel-in-progress: true jobs: + pre_job: + name: Skip Duplicate Checks + runs-on: ubuntu-latest + outputs: + should_skip: ${{ steps.skip_check.outputs.should_skip || 'false' }} + steps: + - id: 
skip_check + continue-on-error: true + uses: fkirc/skip-duplicate-actions@f75f66ce1886f00957d99748a42c724f4330bdcf # v5.3.1 + with: + cancel_others: 'true' + concurrent_skipping: same_content_newer + # Check which paths were changed to determine which tests to run check-changes: name: Check Changed Files + needs: pre_job + if: needs.pre_job.outputs.should_skip != 'true' runs-on: ubuntu-latest outputs: api-changed: ${{ steps.changes.outputs.api }} + e2e-changed: ${{ steps.changes.outputs.e2e }} web-changed: ${{ steps.changes.outputs.web }} vdb-changed: ${{ steps.changes.outputs.vdb }} migration-changed: ${{ steps.changes.outputs.migration }} @@ -34,49 +54,364 @@ jobs: filters: | api: - 'api/**' - - 'docker/**' - '.github/workflows/api-tests.yml' + - '.github/workflows/expose_service_ports.sh' + - 'docker/.env.example' + - 'docker/middleware.env.example' + - 'docker/docker-compose.middleware.yaml' + - 'docker/docker-compose-template.yaml' + - 'docker/generate_docker_compose' + - 'docker/ssrf_proxy/**' + - 'docker/volumes/sandbox/conf/**' web: - 'web/**' - '.github/workflows/web-tests.yml' - '.github/actions/setup-web/**' + e2e: + - 'api/**' + - 'api/pyproject.toml' + - 'api/uv.lock' + - 'e2e/**' + - 'web/**' + - 'docker/docker-compose.middleware.yaml' + - 'docker/middleware.env.example' + - '.github/workflows/web-e2e.yml' + - '.github/actions/setup-web/**' vdb: - 'api/core/rag/datasource/**' - - 'docker/**' + - 'api/tests/integration_tests/vdb/**' - '.github/workflows/vdb-tests.yml' + - '.github/workflows/expose_service_ports.sh' + - 'docker/.env.example' + - 'docker/middleware.env.example' + - 'docker/docker-compose.yaml' + - 'docker/docker-compose-template.yaml' + - 'docker/generate_docker_compose' + - 'docker/certbot/**' + - 'docker/couchbase-server/**' + - 'docker/elasticsearch/**' + - 'docker/iris/**' + - 'docker/nginx/**' + - 'docker/pgvector/**' + - 'docker/ssrf_proxy/**' + - 'docker/startupscripts/**' + - 'docker/tidb/**' + - 'docker/volumes/**' - 
'api/uv.lock' - 'api/pyproject.toml' migration: - 'api/migrations/**' + - 'api/.env.example' - '.github/workflows/db-migration-test.yml' + - '.github/workflows/expose_service_ports.sh' + - 'docker/.env.example' + - 'docker/middleware.env.example' + - 'docker/docker-compose.middleware.yaml' + - 'docker/docker-compose-template.yaml' + - 'docker/generate_docker_compose' + - 'docker/ssrf_proxy/**' + - 'docker/volumes/sandbox/conf/**' - # Run tests in parallel - api-tests: - name: API Tests - needs: check-changes - if: needs.check-changes.outputs.api-changed == 'true' + # Run tests in parallel while always emitting stable required checks. + api-tests-run: + name: Run API Tests + needs: + - pre_job + - check-changes + if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.api-changed == 'true' uses: ./.github/workflows/api-tests.yml secrets: inherit - web-tests: - name: Web Tests - needs: check-changes - if: needs.check-changes.outputs.web-changed == 'true' + api-tests-skip: + name: Skip API Tests + needs: + - pre_job + - check-changes + if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.api-changed != 'true' + runs-on: ubuntu-latest + steps: + - name: Report skipped API tests + run: echo "No API-related changes detected; skipping API tests." + + api-tests: + name: API Tests + if: ${{ always() }} + needs: + - pre_job + - check-changes + - api-tests-run + - api-tests-skip + runs-on: ubuntu-latest + steps: + - name: Finalize API Tests status + env: + SHOULD_SKIP_WORKFLOW: ${{ needs.pre_job.outputs.should_skip }} + TESTS_CHANGED: ${{ needs.check-changes.outputs.api-changed }} + RUN_RESULT: ${{ needs.api-tests-run.result }} + SKIP_RESULT: ${{ needs.api-tests-skip.result }} + run: | + if [[ "$SHOULD_SKIP_WORKFLOW" == 'true' ]]; then + echo "API tests were skipped because this workflow run duplicated a successful or newer run." 
+ exit 0 + fi + + if [[ "$TESTS_CHANGED" == 'true' ]]; then + if [[ "$RUN_RESULT" == 'success' ]]; then + echo "API tests ran successfully." + exit 0 + fi + + echo "API tests were required but finished with result: $RUN_RESULT" >&2 + exit 1 + fi + + if [[ "$SKIP_RESULT" == 'success' ]]; then + echo "API tests were skipped because no API-related files changed." + exit 0 + fi + + echo "API tests were not required, but the skip job finished with result: $SKIP_RESULT" >&2 + exit 1 + + web-tests-run: + name: Run Web Tests + needs: + - pre_job + - check-changes + if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.web-changed == 'true' uses: ./.github/workflows/web-tests.yml secrets: inherit + web-tests-skip: + name: Skip Web Tests + needs: + - pre_job + - check-changes + if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.web-changed != 'true' + runs-on: ubuntu-latest + steps: + - name: Report skipped web tests + run: echo "No web-related changes detected; skipping web tests." + + web-tests: + name: Web Tests + if: ${{ always() }} + needs: + - pre_job + - check-changes + - web-tests-run + - web-tests-skip + runs-on: ubuntu-latest + steps: + - name: Finalize Web Tests status + env: + SHOULD_SKIP_WORKFLOW: ${{ needs.pre_job.outputs.should_skip }} + TESTS_CHANGED: ${{ needs.check-changes.outputs.web-changed }} + RUN_RESULT: ${{ needs.web-tests-run.result }} + SKIP_RESULT: ${{ needs.web-tests-skip.result }} + run: | + if [[ "$SHOULD_SKIP_WORKFLOW" == 'true' ]]; then + echo "Web tests were skipped because this workflow run duplicated a successful or newer run." + exit 0 + fi + + if [[ "$TESTS_CHANGED" == 'true' ]]; then + if [[ "$RUN_RESULT" == 'success' ]]; then + echo "Web tests ran successfully." + exit 0 + fi + + echo "Web tests were required but finished with result: $RUN_RESULT" >&2 + exit 1 + fi + + if [[ "$SKIP_RESULT" == 'success' ]]; then + echo "Web tests were skipped because no web-related files changed." 
+ exit 0 + fi + + echo "Web tests were not required, but the skip job finished with result: $SKIP_RESULT" >&2 + exit 1 + + web-e2e-run: + name: Run Web Full-Stack E2E + needs: + - pre_job + - check-changes + if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.e2e-changed == 'true' + uses: ./.github/workflows/web-e2e.yml + + web-e2e-skip: + name: Skip Web Full-Stack E2E + needs: + - pre_job + - check-changes + if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.e2e-changed != 'true' + runs-on: ubuntu-latest + steps: + - name: Report skipped web full-stack e2e + run: echo "No E2E-related changes detected; skipping web full-stack E2E." + + web-e2e: + name: Web Full-Stack E2E + if: ${{ always() }} + needs: + - pre_job + - check-changes + - web-e2e-run + - web-e2e-skip + runs-on: ubuntu-latest + steps: + - name: Finalize Web Full-Stack E2E status + env: + SHOULD_SKIP_WORKFLOW: ${{ needs.pre_job.outputs.should_skip }} + TESTS_CHANGED: ${{ needs.check-changes.outputs.e2e-changed }} + RUN_RESULT: ${{ needs.web-e2e-run.result }} + SKIP_RESULT: ${{ needs.web-e2e-skip.result }} + run: | + if [[ "$SHOULD_SKIP_WORKFLOW" == 'true' ]]; then + echo "Web full-stack E2E was skipped because this workflow run duplicated a successful or newer run." + exit 0 + fi + + if [[ "$TESTS_CHANGED" == 'true' ]]; then + if [[ "$RUN_RESULT" == 'success' ]]; then + echo "Web full-stack E2E ran successfully." + exit 0 + fi + + echo "Web full-stack E2E was required but finished with result: $RUN_RESULT" >&2 + exit 1 + fi + + if [[ "$SKIP_RESULT" == 'success' ]]; then + echo "Web full-stack E2E was skipped because no E2E-related files changed." 
+ exit 0 + fi + + echo "Web full-stack E2E was not required, but the skip job finished with result: $SKIP_RESULT" >&2 + exit 1 + style-check: name: Style Check + needs: pre_job + if: needs.pre_job.outputs.should_skip != 'true' uses: ./.github/workflows/style.yml + vdb-tests-run: + name: Run VDB Tests + needs: + - pre_job + - check-changes + if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.vdb-changed == 'true' + uses: ./.github/workflows/vdb-tests.yml + + vdb-tests-skip: + name: Skip VDB Tests + needs: + - pre_job + - check-changes + if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.vdb-changed != 'true' + runs-on: ubuntu-latest + steps: + - name: Report skipped VDB tests + run: echo "No VDB-related changes detected; skipping VDB tests." + vdb-tests: name: VDB Tests - needs: check-changes - if: needs.check-changes.outputs.vdb-changed == 'true' - uses: ./.github/workflows/vdb-tests.yml + if: ${{ always() }} + needs: + - pre_job + - check-changes + - vdb-tests-run + - vdb-tests-skip + runs-on: ubuntu-latest + steps: + - name: Finalize VDB Tests status + env: + SHOULD_SKIP_WORKFLOW: ${{ needs.pre_job.outputs.should_skip }} + TESTS_CHANGED: ${{ needs.check-changes.outputs.vdb-changed }} + RUN_RESULT: ${{ needs.vdb-tests-run.result }} + SKIP_RESULT: ${{ needs.vdb-tests-skip.result }} + run: | + if [[ "$SHOULD_SKIP_WORKFLOW" == 'true' ]]; then + echo "VDB tests were skipped because this workflow run duplicated a successful or newer run." + exit 0 + fi + + if [[ "$TESTS_CHANGED" == 'true' ]]; then + if [[ "$RUN_RESULT" == 'success' ]]; then + echo "VDB tests ran successfully." + exit 0 + fi + + echo "VDB tests were required but finished with result: $RUN_RESULT" >&2 + exit 1 + fi + + if [[ "$SKIP_RESULT" == 'success' ]]; then + echo "VDB tests were skipped because no VDB-related files changed." 
+ exit 0 + fi + + echo "VDB tests were not required, but the skip job finished with result: $SKIP_RESULT" >&2 + exit 1 + + db-migration-test-run: + name: Run DB Migration Test + needs: + - pre_job + - check-changes + if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.migration-changed == 'true' + uses: ./.github/workflows/db-migration-test.yml + + db-migration-test-skip: + name: Skip DB Migration Test + needs: + - pre_job + - check-changes + if: needs.pre_job.outputs.should_skip != 'true' && needs.check-changes.outputs.migration-changed != 'true' + runs-on: ubuntu-latest + steps: + - name: Report skipped DB migration tests + run: echo "No migration-related changes detected; skipping DB migration tests." db-migration-test: name: DB Migration Test - needs: check-changes - if: needs.check-changes.outputs.migration-changed == 'true' - uses: ./.github/workflows/db-migration-test.yml + if: ${{ always() }} + needs: + - pre_job + - check-changes + - db-migration-test-run + - db-migration-test-skip + runs-on: ubuntu-latest + steps: + - name: Finalize DB Migration Test status + env: + SHOULD_SKIP_WORKFLOW: ${{ needs.pre_job.outputs.should_skip }} + TESTS_CHANGED: ${{ needs.check-changes.outputs.migration-changed }} + RUN_RESULT: ${{ needs.db-migration-test-run.result }} + SKIP_RESULT: ${{ needs.db-migration-test-skip.result }} + run: | + if [[ "$SHOULD_SKIP_WORKFLOW" == 'true' ]]; then + echo "DB migration tests were skipped because this workflow run duplicated a successful or newer run." + exit 0 + fi + + if [[ "$TESTS_CHANGED" == 'true' ]]; then + if [[ "$RUN_RESULT" == 'success' ]]; then + echo "DB migration tests ran successfully." + exit 0 + fi + + echo "DB migration tests were required but finished with result: $RUN_RESULT" >&2 + exit 1 + fi + + if [[ "$SKIP_RESULT" == 'success' ]]; then + echo "DB migration tests were skipped because no migration-related files changed." 
+ exit 0 + fi + + echo "DB migration tests were not required, but the skip job finished with result: $SKIP_RESULT" >&2 + exit 1 diff --git a/.github/workflows/semantic-pull-request.yml b/.github/workflows/semantic-pull-request.yml index c21331ec0d..49d2e94695 100644 --- a/.github/workflows/semantic-pull-request.yml +++ b/.github/workflows/semantic-pull-request.yml @@ -7,6 +7,9 @@ on: - edited - reopened - synchronize + merge_group: + branches: ["main"] + types: [checks_requested] jobs: lint: @@ -15,7 +18,11 @@ jobs: pull-requests: read runs-on: ubuntu-latest steps: + - name: Complete merge group check + if: github.event_name == 'merge_group' + run: echo "Semantic PR title validation is handled on pull requests." - name: Check title + if: github.event_name == 'pull_request' uses: amannn/action-semantic-pull-request@48f256284bd46cdaab1048c3721360e808335d50 # v6.1.1 env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml index 23ae36f7b1..7b269ccf4e 100644 --- a/.github/workflows/style.yml +++ b/.github/workflows/style.yml @@ -49,7 +49,7 @@ jobs: - name: Run Type Checks if: steps.changed-files.outputs.any_changed == 'true' - run: make type-check + run: make type-check-core - name: Dotenv check if: steps.changed-files.outputs.any_changed == 'true' diff --git a/.github/workflows/translate-i18n-claude.yml b/.github/workflows/translate-i18n-claude.yml index 1869254295..aaf51aa606 100644 --- a/.github/workflows/translate-i18n-claude.yml +++ b/.github/workflows/translate-i18n-claude.yml @@ -1,26 +1,24 @@ name: Translate i18n Files with Claude Code -# Note: claude-code-action doesn't support push events directly. -# Push events are handled by trigger-i18n-sync.yml which sends repository_dispatch. 
-# See: https://github.com/langgenius/dify/issues/30743 - on: - repository_dispatch: - types: [i18n-sync] + push: + branches: [main] + paths: + - 'web/i18n/en-US/*.json' workflow_dispatch: inputs: files: - description: 'Specific files to translate (space-separated, e.g., "app common"). Leave empty for all files.' + description: 'Specific files to translate (space-separated, e.g., "app common"). Required for full mode; leave empty in incremental mode to use en-US files changed since HEAD~1.' required: false type: string languages: - description: 'Specific languages to translate (space-separated, e.g., "zh-Hans ja-JP"). Leave empty for all supported languages.' + description: 'Specific languages to translate (space-separated, e.g., "zh-Hans ja-JP"). Leave empty for all supported target languages except en-US.' required: false type: string mode: - description: 'Sync mode: incremental (only changes) or full (re-check all keys)' + description: 'Sync mode: incremental (compare with previous en-US revision) or full (sync all keys in scope)' required: false - default: 'incremental' + default: incremental type: choice options: - incremental @@ -30,11 +28,15 @@ permissions: contents: write pull-requests: write +concurrency: + group: translate-i18n-${{ github.event_name }}-${{ github.ref }} + cancel-in-progress: ${{ github.event_name == 'push' }} + jobs: translate: if: github.repository == 'langgenius/dify' runs-on: ubuntu-latest - timeout-minutes: 60 + timeout-minutes: 120 steps: - name: Checkout repository @@ -51,380 +53,161 @@ jobs: - name: Setup web environment uses: ./.github/actions/setup-web - - name: Detect changed files and generate diff - id: detect_changes + - name: Prepare sync context + id: context + shell: bash run: | - if [ "${{ github.event_name }}" == "workflow_dispatch" ]; then - # Manual trigger - if [ -n "${{ github.event.inputs.files }}" ]; then - echo "CHANGED_FILES=${{ github.event.inputs.files }}" >> $GITHUB_OUTPUT - else - # Get all JSON files in 
en-US directory - files=$(ls web/i18n/en-US/*.json 2>/dev/null | xargs -n1 basename | sed 's/.json$//' | tr '\n' ' ') - echo "CHANGED_FILES=$files" >> $GITHUB_OUTPUT - fi - echo "TARGET_LANGS=${{ github.event.inputs.languages }}" >> $GITHUB_OUTPUT - echo "SYNC_MODE=${{ github.event.inputs.mode || 'incremental' }}" >> $GITHUB_OUTPUT + DEFAULT_TARGET_LANGS=$(awk " + /value: '/ { + value=\$2 + gsub(/[',]/, \"\", value) + } + /supported: true/ && value != \"en-US\" { + printf \"%s \", value + } + " web/i18n-config/languages.ts | sed 's/[[:space:]]*$//') - # For manual trigger with incremental mode, get diff from last commit - # For full mode, we'll do a complete check anyway - if [ "${{ github.event.inputs.mode }}" == "full" ]; then - echo "Full mode: will check all keys" > /tmp/i18n-diff.txt - echo "DIFF_AVAILABLE=false" >> $GITHUB_OUTPUT + if [ "${{ github.event_name }}" = "push" ]; then + BASE_SHA="${{ github.event.before }}" + if [ -z "$BASE_SHA" ] || [ "$BASE_SHA" = "0000000000000000000000000000000000000000" ]; then + BASE_SHA=$(git rev-parse HEAD~1 2>/dev/null || true) + fi + HEAD_SHA="${{ github.sha }}" + if [ -n "$BASE_SHA" ]; then + CHANGED_FILES=$(git diff --name-only "$BASE_SHA" "$HEAD_SHA" -- 'web/i18n/en-US/*.json' 2>/dev/null | sed -n 's@^.*/@@p' | sed 's/\.json$//' | tr '\n' ' ' | sed 's/[[:space:]]*$//') else - git diff HEAD~1..HEAD -- 'web/i18n/en-US/*.json' > /tmp/i18n-diff.txt 2>/dev/null || echo "" > /tmp/i18n-diff.txt - if [ -s /tmp/i18n-diff.txt ]; then - echo "DIFF_AVAILABLE=true" >> $GITHUB_OUTPUT - else - echo "DIFF_AVAILABLE=false" >> $GITHUB_OUTPUT - fi - fi - elif [ "${{ github.event_name }}" == "repository_dispatch" ]; then - # Triggered by push via trigger-i18n-sync.yml workflow - # Validate required payload fields - if [ -z "${{ github.event.client_payload.changed_files }}" ]; then - echo "Error: repository_dispatch payload missing required 'changed_files' field" >&2 - exit 1 - fi - echo "CHANGED_FILES=${{ 
github.event.client_payload.changed_files }}" >> $GITHUB_OUTPUT - echo "TARGET_LANGS=" >> $GITHUB_OUTPUT - echo "SYNC_MODE=${{ github.event.client_payload.sync_mode || 'incremental' }}" >> $GITHUB_OUTPUT - - # Decode the base64-encoded diff from the trigger workflow - if [ -n "${{ github.event.client_payload.diff_base64 }}" ]; then - if ! echo "${{ github.event.client_payload.diff_base64 }}" | base64 -d > /tmp/i18n-diff.txt 2>&1; then - echo "Warning: Failed to decode base64 diff payload" >&2 - echo "" > /tmp/i18n-diff.txt - echo "DIFF_AVAILABLE=false" >> $GITHUB_OUTPUT - elif [ -s /tmp/i18n-diff.txt ]; then - echo "DIFF_AVAILABLE=true" >> $GITHUB_OUTPUT - else - echo "DIFF_AVAILABLE=false" >> $GITHUB_OUTPUT - fi - else - echo "" > /tmp/i18n-diff.txt - echo "DIFF_AVAILABLE=false" >> $GITHUB_OUTPUT + CHANGED_FILES=$(find web/i18n/en-US -maxdepth 1 -type f -name '*.json' -print | sed -n 's@^.*/@@p' | sed 's/\.json$//' | sort | tr '\n' ' ' | sed 's/[[:space:]]*$//') fi + TARGET_LANGS="$DEFAULT_TARGET_LANGS" + SYNC_MODE="incremental" else - echo "Unsupported event type: ${{ github.event_name }}" - exit 1 + BASE_SHA="" + HEAD_SHA=$(git rev-parse HEAD) + if [ -n "${{ github.event.inputs.languages }}" ]; then + TARGET_LANGS="${{ github.event.inputs.languages }}" + else + TARGET_LANGS="$DEFAULT_TARGET_LANGS" + fi + SYNC_MODE="${{ github.event.inputs.mode || 'incremental' }}" + if [ -n "${{ github.event.inputs.files }}" ]; then + CHANGED_FILES="${{ github.event.inputs.files }}" + elif [ "$SYNC_MODE" = "incremental" ]; then + BASE_SHA=$(git rev-parse HEAD~1 2>/dev/null || true) + if [ -n "$BASE_SHA" ]; then + CHANGED_FILES=$(git diff --name-only "$BASE_SHA" "$HEAD_SHA" -- 'web/i18n/en-US/*.json' 2>/dev/null | sed -n 's@^.*/@@p' | sed 's/\.json$//' | tr '\n' ' ' | sed 's/[[:space:]]*$//') + else + CHANGED_FILES=$(find web/i18n/en-US -maxdepth 1 -type f -name '*.json' -print | sed -n 's@^.*/@@p' | sed 's/\.json$//' | sort | tr '\n' ' ' | sed 's/[[:space:]]*$//') + fi + elif [ 
"$SYNC_MODE" = "full" ]; then + echo "workflow_dispatch full mode requires the files input to stay within CI limits." >&2 + exit 1 + else + CHANGED_FILES="" + fi fi - # Truncate diff if too large (keep first 50KB) - if [ -f /tmp/i18n-diff.txt ]; then - head -c 50000 /tmp/i18n-diff.txt > /tmp/i18n-diff-truncated.txt - mv /tmp/i18n-diff-truncated.txt /tmp/i18n-diff.txt + FILE_ARGS="" + if [ -n "$CHANGED_FILES" ]; then + FILE_ARGS="--file $CHANGED_FILES" fi - echo "Detected files: $(cat $GITHUB_OUTPUT | grep CHANGED_FILES || echo 'none')" + LANG_ARGS="" + if [ -n "$TARGET_LANGS" ]; then + LANG_ARGS="--lang $TARGET_LANGS" + fi + + { + echo "DEFAULT_TARGET_LANGS=$DEFAULT_TARGET_LANGS" + echo "BASE_SHA=$BASE_SHA" + echo "HEAD_SHA=$HEAD_SHA" + echo "CHANGED_FILES=$CHANGED_FILES" + echo "TARGET_LANGS=$TARGET_LANGS" + echo "SYNC_MODE=$SYNC_MODE" + echo "FILE_ARGS=$FILE_ARGS" + echo "LANG_ARGS=$LANG_ARGS" + } >> "$GITHUB_OUTPUT" + + echo "Files: ${CHANGED_FILES:-}" + echo "Languages: ${TARGET_LANGS:-}" + echo "Mode: $SYNC_MODE" - name: Run Claude Code for Translation Sync - if: steps.detect_changes.outputs.CHANGED_FILES != '' - uses: anthropics/claude-code-action@ff9acae5886d41a99ed4ec14b7dc147d55834722 # v1.0.77 + if: steps.context.outputs.CHANGED_FILES != '' + uses: anthropics/claude-code-action@88c168b39e7e64da0286d812b6e9fbebb6708185 # v1.0.82 with: anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }} github_token: ${{ secrets.GITHUB_TOKEN }} - # Allow github-actions bot to trigger this workflow via repository_dispatch - # See: https://github.com/anthropics/claude-code-action/blob/main/docs/usage.md allowed_bots: 'github-actions[bot]' + show_full_output: ${{ github.event_name == 'workflow_dispatch' }} prompt: | - You are a professional i18n synchronization engineer for the Dify project. - Your task is to keep all language translations in sync with the English source (en-US). + You are the i18n sync agent for the Dify repository. 
+ Your job is to keep translations synchronized with the English source files under `${{ github.workspace }}/web/i18n/en-US/`, then open a PR with the result. - ## CRITICAL TOOL RESTRICTIONS - - Use **Read** tool to read files (NOT cat or bash) - - Use **Edit** tool to modify JSON files (NOT node, jq, or bash scripts) - - Use **Bash** ONLY for: git commands, gh commands, pnpm commands - - Run bash commands ONE BY ONE, never combine with && or || - - NEVER use `$()` command substitution - it's not supported. Split into separate commands instead. + Use absolute paths at all times: + - Repo root: `${{ github.workspace }}` + - Web directory: `${{ github.workspace }}/web` + - Language config: `${{ github.workspace }}/web/i18n-config/languages.ts` - ## WORKING DIRECTORY & ABSOLUTE PATHS - Claude Code sandbox working directory may vary. Always use absolute paths: - - For pnpm: `pnpm --dir ${{ github.workspace }}/web ` - - For git: `git -C ${{ github.workspace }} ` - - For gh: `gh --repo ${{ github.repository }} ` - - For file paths: `${{ github.workspace }}/web/i18n/` + Inputs: + - Files in scope: `${{ steps.context.outputs.CHANGED_FILES }}` + - Target languages: `${{ steps.context.outputs.TARGET_LANGS }}` + - Sync mode: `${{ steps.context.outputs.SYNC_MODE }}` + - Base SHA: `${{ steps.context.outputs.BASE_SHA }}` + - Head SHA: `${{ steps.context.outputs.HEAD_SHA }}` + - Scoped file args: `${{ steps.context.outputs.FILE_ARGS }}` + - Scoped language args: `${{ steps.context.outputs.LANG_ARGS }}` - ## EFFICIENCY RULES - - **ONE Edit per language file** - batch all key additions into a single Edit - - Insert new keys at the beginning of JSON (after `{`), lint:fix will sort them - - Translate ALL keys for a language mentally first, then do ONE Edit - - ## Context - - Changed/target files: ${{ steps.detect_changes.outputs.CHANGED_FILES }} - - Target languages (empty means all supported): ${{ steps.detect_changes.outputs.TARGET_LANGS }} - - Sync mode: ${{ 
steps.detect_changes.outputs.SYNC_MODE }} - - Translation files are located in: ${{ github.workspace }}/web/i18n/{locale}/{filename}.json - - Language configuration is in: ${{ github.workspace }}/web/i18n-config/languages.ts - - Git diff is available: ${{ steps.detect_changes.outputs.DIFF_AVAILABLE }} - - ## CRITICAL DESIGN: Verify First, Then Sync - - You MUST follow this three-phase approach: - - ═══════════════════════════════════════════════════════════════ - ║ PHASE 1: VERIFY - Analyze and Generate Change Report ║ - ═══════════════════════════════════════════════════════════════ - - ### Step 1.1: Analyze Git Diff (for incremental mode) - Use the Read tool to read `/tmp/i18n-diff.txt` to see the git diff. - - Parse the diff to categorize changes: - - Lines with `+` (not `+++`): Added or modified values - - Lines with `-` (not `---`): Removed or old values - - Identify specific keys for each category: - * ADD: Keys that appear only in `+` lines (new keys) - * UPDATE: Keys that appear in both `-` and `+` lines (value changed) - * DELETE: Keys that appear only in `-` lines (removed keys) - - ### Step 1.2: Read Language Configuration - Use the Read tool to read `${{ github.workspace }}/web/i18n-config/languages.ts`. - Extract all languages with `supported: true`. 
- - ### Step 1.3: Run i18n:check for Each Language - ```bash - pnpm --dir ${{ github.workspace }}/web install --frozen-lockfile - ``` - ```bash - pnpm --dir ${{ github.workspace }}/web run i18n:check - ``` - - This will report: - - Missing keys (need to ADD) - - Extra keys (need to DELETE) - - ### Step 1.4: Generate Change Report - - Create a structured report identifying: - ``` - ╔══════════════════════════════════════════════════════════════╗ - ║ I18N SYNC CHANGE REPORT ║ - ╠══════════════════════════════════════════════════════════════╣ - ║ Files to process: [list] ║ - ║ Languages to sync: [list] ║ - ╠══════════════════════════════════════════════════════════════╣ - ║ ADD (New Keys): ║ - ║ - [filename].[key]: "English value" ║ - ║ ... ║ - ╠══════════════════════════════════════════════════════════════╣ - ║ UPDATE (Modified Keys - MUST re-translate): ║ - ║ - [filename].[key]: "Old value" → "New value" ║ - ║ ... ║ - ╠══════════════════════════════════════════════════════════════╣ - ║ DELETE (Extra Keys): ║ - ║ - [language]/[filename].[key] ║ - ║ ... ║ - ╚══════════════════════════════════════════════════════════════╝ - ``` - - **IMPORTANT**: For UPDATE detection, compare git diff to find keys where - the English value changed. These MUST be re-translated even if target - language already has a translation (it's now stale!). - - ═══════════════════════════════════════════════════════════════ - ║ PHASE 2: SYNC - Execute Changes Based on Report ║ - ═══════════════════════════════════════════════════════════════ - - ### Step 2.1: Process ADD Operations (BATCH per language file) - - **CRITICAL WORKFLOW for efficiency:** - 1. First, translate ALL new keys for ALL languages mentally - 2. 
Then, for EACH language file, do ONE Edit operation: - - Read the file once - - Insert ALL new keys at the beginning (right after the opening `{`) - - Don't worry about alphabetical order - lint:fix will sort them later - - Example Edit (adding 3 keys to zh-Hans/app.json): - ``` - old_string: '{\n "accessControl"' - new_string: '{\n "newKey1": "translation1",\n "newKey2": "translation2",\n "newKey3": "translation3",\n "accessControl"' - ``` - - **IMPORTANT**: - - ONE Edit per language file (not one Edit per key!) - - Always use the Edit tool. NEVER use bash scripts, node, or jq. - - ### Step 2.2: Process UPDATE Operations - - **IMPORTANT: Special handling for zh-Hans and ja-JP** - If zh-Hans or ja-JP files were ALSO modified in the same push: - - Run: `git -C ${{ github.workspace }} diff HEAD~1 --name-only` and check for zh-Hans or ja-JP files - - If found, it means someone manually translated them. Apply these rules: - - 1. **Missing keys**: Still ADD them (completeness required) - 2. **Existing translations**: Compare with the NEW English value: - - If translation is **completely wrong** or **unrelated** → Update it - - If translation is **roughly correct** (captures the meaning) → Keep it, respect manual work - - When in doubt, **keep the manual translation** - - Example: - - English changed: "Save" → "Save Changes" - - Manual translation: "保存更改" → Keep it (correct meaning) - - Manual translation: "删除" → Update it (completely wrong) - - For other languages: - Use Edit tool to replace the old value with the new translation. - You can batch multiple updates in one Edit if they are adjacent. 
- - ### Step 2.3: Process DELETE Operations - For extra keys reported by i18n:check: - - Run: `pnpm --dir ${{ github.workspace }}/web run i18n:check --auto-remove` - - Or manually remove from target language JSON files - - ## Translation Guidelines - - - PRESERVE all placeholders exactly as-is: - - `{{variable}}` - Mustache interpolation - - `${variable}` - Template literal - - `content` - HTML tags - - `_one`, `_other` - Pluralization suffixes (these are KEY suffixes, not values) - - **CRITICAL: Variable names and tag names MUST stay in English - NEVER translate them** - - ✅ CORRECT examples: - - English: "{{count}} items" → Japanese: "{{count}} 個のアイテム" - - English: "{{name}} updated" → Korean: "{{name}} 업데이트됨" - - English: "{{email}}" → Chinese: "{{email}}" - - English: "Marketplace" → Japanese: "マーケットプレイス" - - ❌ WRONG examples (NEVER do this - will break the application): - - "{{count}}" → "{{カウント}}" ❌ (variable name translated to Japanese) - - "{{name}}" → "{{이름}}" ❌ (variable name translated to Korean) - - "{{email}}" → "{{邮箱}}" ❌ (variable name translated to Chinese) - - "" → "<メール>" ❌ (tag name translated) - - "" → "<自定义链接>" ❌ (component name translated) - - - Use appropriate language register (formal/informal) based on existing translations - - Match existing translation style in each language - - Technical terms: check existing conventions per language - - For CJK languages: no spaces between characters unless necessary - - For RTL languages (ar-TN, fa-IR): ensure proper text handling - - ## Output Format Requirements - - Alphabetical key ordering (if original file uses it) - - 2-space indentation - - Trailing newline at end of file - - Valid JSON (use proper escaping for special characters) - - ═══════════════════════════════════════════════════════════════ - ║ PHASE 3: RE-VERIFY - Confirm All Issues Resolved ║ - ═══════════════════════════════════════════════════════════════ - - ### Step 3.1: Run Lint Fix (IMPORTANT!) 
- ```bash - pnpm --dir ${{ github.workspace }}/web lint:fix --quiet -- 'i18n/**/*.json' - ``` - This ensures: - - JSON keys are sorted alphabetically (jsonc/sort-keys rule) - - Valid i18n keys (dify-i18n/valid-i18n-keys rule) - - No extra keys (dify-i18n/no-extra-keys rule) - - ### Step 3.2: Run Final i18n Check - ```bash - pnpm --dir ${{ github.workspace }}/web run i18n:check - ``` - - ### Step 3.3: Fix Any Remaining Issues - If check reports issues: - - Go back to PHASE 2 for unresolved items - - Repeat until check passes - - ### Step 3.4: Generate Final Summary - ``` - ╔══════════════════════════════════════════════════════════════╗ - ║ SYNC COMPLETED SUMMARY ║ - ╠══════════════════════════════════════════════════════════════╣ - ║ Language │ Added │ Updated │ Deleted │ Status ║ - ╠══════════════════════════════════════════════════════════════╣ - ║ zh-Hans │ 5 │ 2 │ 1 │ ✓ Complete ║ - ║ ja-JP │ 5 │ 2 │ 1 │ ✓ Complete ║ - ║ ... │ ... │ ... │ ... │ ... ║ - ╠══════════════════════════════════════════════════════════════╣ - ║ i18n:check │ PASSED - All keys in sync ║ - ╚══════════════════════════════════════════════════════════════╝ - ``` - - ## Mode-Specific Behavior - - **SYNC_MODE = "incremental"** (default): - - Focus on keys identified from git diff - - Also check i18n:check output for any missing/extra keys - - Efficient for small changes - - **SYNC_MODE = "full"**: - - Compare ALL keys between en-US and each language - - Run i18n:check to identify all discrepancies - - Use for first-time sync or fixing historical issues - - ## Important Notes - - 1. Always run i18n:check BEFORE and AFTER making changes - 2. The check script is the source of truth for missing/extra keys - 3. For UPDATE scenario: git diff is the source of truth for changed values - 4. Create a single commit with all translation changes - 5. 
If any translation fails, continue with others and report failures - - ═══════════════════════════════════════════════════════════════ - ║ PHASE 4: COMMIT AND CREATE PR ║ - ═══════════════════════════════════════════════════════════════ - - After all translations are complete and verified: - - ### Step 4.1: Check for changes - ```bash - git -C ${{ github.workspace }} status --porcelain - ``` - - If there are changes: - - ### Step 4.2: Create a new branch and commit - Run these git commands ONE BY ONE (not combined with &&). - **IMPORTANT**: Do NOT use `$()` command substitution. Use two separate commands: - - 1. First, get the timestamp: - ```bash - date +%Y%m%d-%H%M%S - ``` - (Note the output, e.g., "20260115-143052") - - 2. Then create branch using the timestamp value: - ```bash - git -C ${{ github.workspace }} checkout -b chore/i18n-sync-20260115-143052 - ``` - (Replace "20260115-143052" with the actual timestamp from step 1) - - 3. Stage changes: - ```bash - git -C ${{ github.workspace }} add web/i18n/ - ``` - - 4. Commit: - ```bash - git -C ${{ github.workspace }} commit -m "chore(i18n): sync translations with en-US - Mode: ${{ steps.detect_changes.outputs.SYNC_MODE }}" - ``` - - 5. Push: - ```bash - git -C ${{ github.workspace }} push origin HEAD - ``` - - ### Step 4.3: Create Pull Request - ```bash - gh pr create --repo ${{ github.repository }} --title "chore(i18n): sync translations with en-US" --body "## Summary - - This PR was automatically generated to sync i18n translation files. - - ### Changes - - Mode: ${{ steps.detect_changes.outputs.SYNC_MODE }} - - Files processed: ${{ steps.detect_changes.outputs.CHANGED_FILES }} - - ### Verification - - [x] \`i18n:check\` passed - - [x] \`lint:fix\` applied - - 🤖 Generated with Claude Code GitHub Action" --base main - ``` + Tool rules: + - Use Read for repository files. + - Use Edit for JSON updates. + - Use Bash only for `git`, `gh`, `pnpm`, and `date`. + - Run Bash commands one by one. 
Do not combine commands with `&&`, `||`, pipes, or command substitution. + Required execution plan: + 1. Resolve target languages. + - Use the provided `Target languages` value as the source of truth. + - If it is unexpectedly empty, read `${{ github.workspace }}/web/i18n-config/languages.ts` and use every language with `supported: true` except `en-US`. + 2. Stay strictly in scope. + - Only process the files listed in `Files in scope`. + - Only process the resolved target languages, never `en-US`. + - Do not touch unrelated i18n files. + - Do not modify `${{ github.workspace }}/web/i18n/en-US/`. + 3. Detect English changes per file. + - Read the current English JSON file for each file in scope. + - If sync mode is `incremental` and `Base SHA` is not empty, run: + `git -C ${{ github.workspace }} show :web/i18n/en-US/.json` + - If sync mode is `full` or `Base SHA` is empty, skip historical comparison and treat the current English file as the only source of truth for structural sync. + - If the file did not exist at Base SHA, treat all current keys as ADD. + - Compare previous and current English JSON to identify: + - ADD: key only in current + - UPDATE: key exists in both and the English value changed + - DELETE: key only in previous + - Do not rely on a truncated diff file. + 4. Run a scoped pre-check before editing: + - `pnpm --dir ${{ github.workspace }}/web run i18n:check ${{ steps.context.outputs.FILE_ARGS }} ${{ steps.context.outputs.LANG_ARGS }}` + - Use this command as the source of truth for missing and extra keys inside the current scope. + 5. Apply translations. + - For every target language and scoped file: + - If the locale file does not exist yet, create it with `Write` and then continue with `Edit` as needed. + - ADD missing keys. + - UPDATE stale translations when the English value changed. + - DELETE removed keys. 
Prefer `pnpm --dir ${{ github.workspace }}/web run i18n:check ${{ steps.context.outputs.FILE_ARGS }} ${{ steps.context.outputs.LANG_ARGS }} --auto-remove` for extra keys so deletions stay in scope. + - For `zh-Hans` and `ja-JP`, if the locale file also changed between Base SHA and Head SHA, preserve manual translations unless they are clearly wrong for the new English value. If in doubt, keep the manual translation. + - Preserve placeholders exactly: `{{variable}}`, `${variable}`, HTML tags, component tags, and variable names. + - Match the existing terminology and register used by each locale. + - Prefer one Edit per file when stable, but prioritize correctness over batching. + 6. Verify only the edited files. + - Run `pnpm --dir ${{ github.workspace }}/web lint:fix --quiet -- ` + - Run `pnpm --dir ${{ github.workspace }}/web run i18n:check ${{ steps.context.outputs.FILE_ARGS }} ${{ steps.context.outputs.LANG_ARGS }}` + - If verification fails, fix the remaining problems before continuing. + 7. Create a PR only when there are changes in `web/i18n/`. + - Check `git -C ${{ github.workspace }} status --porcelain -- web/i18n/` + - Create branch `chore/i18n-sync-` + - Commit message: `chore(i18n): sync translations with en-US` + - Push the branch and open a PR against `main` + - PR title: `chore(i18n): sync translations with en-US` + - PR body: summarize files, languages, sync mode, and verification commands + 8. If there are no translation changes after verification, do not create a branch, commit, or PR. 
claude_args: | - --max-turns 150 + --max-turns 80 --allowedTools "Read,Write,Edit,Bash(git *),Bash(git:*),Bash(gh *),Bash(gh:*),Bash(pnpm *),Bash(pnpm:*),Bash(date *),Bash(date:*),Glob,Grep" diff --git a/.github/workflows/trigger-i18n-sync.yml b/.github/workflows/trigger-i18n-sync.yml deleted file mode 100644 index 1caaddd47a..0000000000 --- a/.github/workflows/trigger-i18n-sync.yml +++ /dev/null @@ -1,66 +0,0 @@ -name: Trigger i18n Sync on Push - -# This workflow bridges the push event to repository_dispatch -# because claude-code-action doesn't support push events directly. -# See: https://github.com/langgenius/dify/issues/30743 - -on: - push: - branches: [main] - paths: - - 'web/i18n/en-US/*.json' - -permissions: - contents: write - -jobs: - trigger: - if: github.repository == 'langgenius/dify' - runs-on: ubuntu-latest - timeout-minutes: 5 - - steps: - - name: Checkout repository - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 - with: - fetch-depth: 0 - - - name: Detect changed files and generate diff - id: detect - run: | - BEFORE_SHA="${{ github.event.before }}" - # Handle edge case: force push may have null/zero SHA - if [ -z "$BEFORE_SHA" ] || [ "$BEFORE_SHA" = "0000000000000000000000000000000000000000" ]; then - BEFORE_SHA="HEAD~1" - fi - - # Detect changed i18n files - changed=$(git diff --name-only "$BEFORE_SHA" "${{ github.sha }}" -- 'web/i18n/en-US/*.json' 2>/dev/null | xargs -n1 basename 2>/dev/null | sed 's/.json$//' | tr '\n' ' ' || echo "") - echo "changed_files=$changed" >> $GITHUB_OUTPUT - - # Generate diff for context - git diff "$BEFORE_SHA" "${{ github.sha }}" -- 'web/i18n/en-US/*.json' > /tmp/i18n-diff.txt 2>/dev/null || echo "" > /tmp/i18n-diff.txt - - # Truncate if too large (keep first 50KB to match receiving workflow) - head -c 50000 /tmp/i18n-diff.txt > /tmp/i18n-diff-truncated.txt - mv /tmp/i18n-diff-truncated.txt /tmp/i18n-diff.txt - - # Base64 encode the diff for safe JSON transport (portable, single-line) - 
diff_base64=$(base64 < /tmp/i18n-diff.txt | tr -d '\n') - echo "diff_base64=$diff_base64" >> $GITHUB_OUTPUT - - if [ -n "$changed" ]; then - echo "has_changes=true" >> $GITHUB_OUTPUT - echo "Detected changed files: $changed" - else - echo "has_changes=false" >> $GITHUB_OUTPUT - echo "No i18n changes detected" - fi - - - name: Trigger i18n sync workflow - if: steps.detect.outputs.has_changes == 'true' - uses: peter-evans/repository-dispatch@28959ce8df70de7be546dd1250a005dd32156697 # v4.0.1 - with: - token: ${{ secrets.GITHUB_TOKEN }} - event-type: i18n-sync - client-payload: '{"changed_files": "${{ steps.detect.outputs.changed_files }}", "diff_base64": "${{ steps.detect.outputs.diff_base64 }}", "sync_mode": "incremental", "trigger_sha": "${{ github.sha }}"}' diff --git a/.github/workflows/vdb-tests.yml b/.github/workflows/vdb-tests.yml index f45f2137d6..7c4cd0ba8c 100644 --- a/.github/workflows/vdb-tests.yml +++ b/.github/workflows/vdb-tests.yml @@ -14,7 +14,6 @@ jobs: strategy: matrix: python-version: - - "3.11" - "3.12" steps: diff --git a/.github/workflows/web-e2e.yml b/.github/workflows/web-e2e.yml new file mode 100644 index 0000000000..8035d1ef8e --- /dev/null +++ b/.github/workflows/web-e2e.yml @@ -0,0 +1,72 @@ +name: Web Full-Stack E2E + +on: + workflow_call: + +permissions: + contents: read + +concurrency: + group: web-e2e-${{ github.head_ref || github.run_id }} + cancel-in-progress: true + +jobs: + test: + name: Web Full-Stack E2E + runs-on: ubuntu-latest + defaults: + run: + shell: bash + + steps: + - name: Checkout code + uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 + with: + persist-credentials: false + + - name: Setup web dependencies + uses: ./.github/actions/setup-web + + - name: Install E2E package dependencies + working-directory: ./e2e + run: vp install --frozen-lockfile + + - name: Setup UV and Python + uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # v7.6.0 + with: + enable-cache: true + 
python-version: "3.12" + cache-dependency-glob: api/uv.lock + + - name: Install API dependencies + run: uv sync --project api --dev + + - name: Install Playwright browser + working-directory: ./e2e + run: vp run e2e:install + + - name: Run isolated source-api and built-web Cucumber E2E tests + working-directory: ./e2e + env: + E2E_ADMIN_EMAIL: e2e-admin@example.com + E2E_ADMIN_NAME: E2E Admin + E2E_ADMIN_PASSWORD: E2eAdmin12345 + E2E_FORCE_WEB_BUILD: "1" + E2E_INIT_PASSWORD: E2eInit12345 + run: vp run e2e:full + + - name: Upload Cucumber report + if: ${{ !cancelled() }} + uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0 + with: + name: cucumber-report + path: e2e/cucumber-report + retention-days: 7 + + - name: Upload E2E logs + if: ${{ !cancelled() }} + uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f # v7.0.0 + with: + name: e2e-logs + path: e2e/.logs + retention-days: 7 diff --git a/.github/workflows/web-tests.yml b/.github/workflows/web-tests.yml index d40cd4bfeb..8110a16355 100644 --- a/.github/workflows/web-tests.yml +++ b/.github/workflows/web-tests.yml @@ -22,8 +22,8 @@ jobs: strategy: fail-fast: false matrix: - shardIndex: [1, 2, 3, 4, 5, 6] - shardTotal: [6] + shardIndex: [1, 2, 3, 4] + shardTotal: [4] defaults: run: shell: bash @@ -66,7 +66,6 @@ jobs: - name: Checkout code uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2 with: - fetch-depth: 0 persist-credentials: false - name: Setup web environment diff --git a/Makefile b/Makefile index 55871c86a7..c377b7c671 100644 --- a/Makefile +++ b/Makefile @@ -74,6 +74,12 @@ type-check: @uv --directory api run mypy --exclude-gitignore --exclude 'tests/' --exclude 'migrations/' --check-untyped-defs --disable-error-code=import-untyped . @echo "✅ Type checks complete" +type-check-core: + @echo "📝 Running core type checks (basedpyright + mypy)..." 
+ @./dev/basedpyright-check $(PATH_TO_CHECK) + @uv --directory api run mypy --exclude-gitignore --exclude 'tests/' --exclude 'migrations/' --check-untyped-defs --disable-error-code=import-untyped . + @echo "✅ Core type checks complete" + test: @echo "🧪 Running backend unit tests..." @if [ -n "$(TARGET_TESTS)" ]; then \ @@ -133,6 +139,7 @@ help: @echo " make check - Check code with ruff" @echo " make lint - Format, fix, and lint code (ruff, imports, dotenv)" @echo " make type-check - Run type checks (basedpyright, pyrefly, mypy)" + @echo " make type-check-core - Run core type checks (basedpyright, mypy)" @echo " make test - Run backend unit tests (or TARGET_TESTS=./api/tests/)" @echo "" @echo "Docker Build Targets:" diff --git a/README.md b/README.md index bef8f6b782..d9848a6c78 100644 --- a/README.md +++ b/README.md @@ -53,7 +53,11 @@ Türkçe README README Tiếng Việt README in Deutsch + README in Italiano + README em Português do Brasil + README Slovenščina README in বাংলা + README in हिन्दी

Dify is an open-source LLM app development platform. Its intuitive interface combines AI workflow, RAG pipeline, agent capabilities, model management, observability features (including [Opik](https://www.comet.com/docs/opik/integrations/dify), [Langfuse](https://docs.langfuse.com), and [Arize Phoenix](https://docs.arize.com/phoenix)) and more, letting you quickly go from prototype to production. Here's a list of the core features: diff --git a/api/.env.example b/api/.env.example index 9672a99d55..c6541731e6 100644 --- a/api/.env.example +++ b/api/.env.example @@ -127,7 +127,8 @@ ALIYUN_OSS_AUTH_VERSION=v1 ALIYUN_OSS_REGION=your-region # Don't start with '/'. OSS doesn't support leading slash in object names. ALIYUN_OSS_PATH=your-path -ALIYUN_CLOUDBOX_ID=your-cloudbox-id +# Optional CloudBox ID for Aliyun OSS, DO NOT enable it if you are not using CloudBox. +#ALIYUN_CLOUDBOX_ID=your-cloudbox-id # Google Storage configuration GOOGLE_STORAGE_BUCKET_NAME=your-bucket-name diff --git a/api/.importlinter b/api/.importlinter index a836d09088..5e06947d94 100644 --- a/api/.importlinter +++ b/api/.importlinter @@ -1,202 +1,14 @@ [importlinter] root_packages = core - dify_graph + constants + context configs controllers extensions + factories + libs models tasks services include_external_packages = True - -[importlinter:contract:workflow] -name = Workflow -type=layers -layers = - graph_engine - graph_events - graph - nodes - node_events - runtime - entities -containers = - dify_graph -ignore_imports = - dify_graph.nodes.base.node -> dify_graph.graph_events - dify_graph.nodes.iteration.iteration_node -> dify_graph.graph_events - dify_graph.nodes.loop.loop_node -> dify_graph.graph_events - - dify_graph.nodes.iteration.iteration_node -> dify_graph.graph_engine - dify_graph.nodes.loop.loop_node -> dify_graph.graph_engine - # TODO(QuantumGhost): fix the import violation later - dify_graph.entities.pause_reason -> dify_graph.nodes.human_input.entities - 
-[importlinter:contract:workflow-infrastructure-dependencies] -name = Workflow Infrastructure Dependencies -type = forbidden -source_modules = - dify_graph -forbidden_modules = - extensions.ext_database - extensions.ext_redis -allow_indirect_imports = True -ignore_imports = - dify_graph.nodes.llm.node -> extensions.ext_database - dify_graph.model_runtime.model_providers.__base.ai_model -> extensions.ext_redis - dify_graph.model_runtime.model_providers.model_provider_factory -> extensions.ext_redis - -[importlinter:contract:workflow-external-imports] -name = Workflow External Imports -type = forbidden -source_modules = - dify_graph -forbidden_modules = - configs - controllers - extensions - models - services - tasks - core.agent - core.app - core.base - core.callback_handler - core.datasource - core.db - core.entities - core.errors - core.extension - core.external_data_tool - core.file - core.helper - core.hosting_configuration - core.indexing_runner - core.llm_generator - core.logging - core.mcp - core.memory - core.moderation - core.ops - core.plugin - core.prompt - core.provider_manager - core.rag - core.repositories - core.schemas - core.tools - core.trigger - core.variables -ignore_imports = - dify_graph.nodes.llm.llm_utils -> core.model_manager - dify_graph.nodes.llm.protocols -> core.model_manager - dify_graph.nodes.llm.llm_utils -> dify_graph.model_runtime.model_providers.__base.large_language_model - dify_graph.nodes.llm.node -> core.tools.signature - dify_graph.nodes.tool.tool_node -> core.callback_handler.workflow_tool_callback_handler - dify_graph.nodes.tool.tool_node -> core.tools.tool_engine - dify_graph.nodes.tool.tool_node -> core.tools.tool_manager - dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.advanced_prompt_transform - dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.simple_prompt_transform - dify_graph.nodes.parameter_extractor.parameter_extractor_node -> 
dify_graph.model_runtime.model_providers.__base.large_language_model - dify_graph.nodes.question_classifier.question_classifier_node -> core.prompt.simple_prompt_transform - dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.model_manager - dify_graph.nodes.question_classifier.question_classifier_node -> core.model_manager - dify_graph.nodes.tool.tool_node -> core.tools.utils.message_transformer - dify_graph.nodes.llm.node -> core.llm_generator.output_parser.errors - dify_graph.nodes.llm.node -> core.llm_generator.output_parser.structured_output - dify_graph.nodes.llm.node -> core.model_manager - dify_graph.nodes.llm.entities -> core.prompt.entities.advanced_prompt_entities - dify_graph.nodes.llm.node -> core.prompt.entities.advanced_prompt_entities - dify_graph.nodes.llm.node -> core.prompt.utils.prompt_message_util - dify_graph.nodes.parameter_extractor.entities -> core.prompt.entities.advanced_prompt_entities - dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.entities.advanced_prompt_entities - dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.utils.prompt_message_util - dify_graph.nodes.question_classifier.entities -> core.prompt.entities.advanced_prompt_entities - dify_graph.nodes.question_classifier.question_classifier_node -> core.prompt.utils.prompt_message_util - dify_graph.nodes.llm.node -> models.dataset - dify_graph.nodes.llm.file_saver -> core.tools.signature - dify_graph.nodes.llm.file_saver -> core.tools.tool_file_manager - dify_graph.nodes.tool.tool_node -> core.tools.errors - dify_graph.nodes.llm.node -> extensions.ext_database - dify_graph.nodes.llm.node -> models.model - dify_graph.nodes.tool.tool_node -> services - dify_graph.model_runtime.model_providers.__base.ai_model -> configs - dify_graph.model_runtime.model_providers.__base.ai_model -> extensions.ext_redis - dify_graph.model_runtime.model_providers.__base.large_language_model -> configs - 
dify_graph.model_runtime.model_providers.__base.text_embedding_model -> core.entities.embedding_type - dify_graph.model_runtime.model_providers.model_provider_factory -> configs - dify_graph.model_runtime.model_providers.model_provider_factory -> extensions.ext_redis - dify_graph.model_runtime.model_providers.model_provider_factory -> models.provider_ids - -[importlinter:contract:rsc] -name = RSC -type = layers -layers = - graph_engine - response_coordinator -containers = - dify_graph.graph_engine - -[importlinter:contract:worker] -name = Worker -type = layers -layers = - graph_engine - worker -containers = - dify_graph.graph_engine - -[importlinter:contract:graph-engine-architecture] -name = Graph Engine Architecture -type = layers -layers = - graph_engine - orchestration - command_processing - event_management - error_handler - graph_traversal - graph_state_manager - worker_management - domain -containers = - dify_graph.graph_engine - -[importlinter:contract:domain-isolation] -name = Domain Model Isolation -type = forbidden -source_modules = - dify_graph.graph_engine.domain -forbidden_modules = - dify_graph.graph_engine.worker_management - dify_graph.graph_engine.command_channels - dify_graph.graph_engine.layers - dify_graph.graph_engine.protocols - -[importlinter:contract:worker-management] -name = Worker Management -type = forbidden -source_modules = - dify_graph.graph_engine.worker_management -forbidden_modules = - dify_graph.graph_engine.orchestration - dify_graph.graph_engine.command_processing - dify_graph.graph_engine.event_management - - -[importlinter:contract:graph-traversal-components] -name = Graph Traversal Components -type = layers -layers = - edge_processor - skip_propagator -containers = - dify_graph.graph_engine.graph_traversal - -[importlinter:contract:command-channels] -name = Command Channels Independence -type = independence -modules = - dify_graph.graph_engine.command_channels.in_memory_channel - 
dify_graph.graph_engine.command_channels.redis_channel diff --git a/api/.ruff.toml b/api/.ruff.toml index b0947eb619..4b1252a861 100644 --- a/api/.ruff.toml +++ b/api/.ruff.toml @@ -100,7 +100,7 @@ ignore = [ "configs/*" = [ "N802", # invalid-function-name ] -"dify_graph/model_runtime/callbacks/base_callback.py" = ["T201"] +"graphon/model_runtime/callbacks/base_callback.py" = ["T201"] "core/workflow/callbacks/workflow_logging_callback.py" = ["T201"] "libs/gmpy2_pkcs10aep_cipher.py" = [ "N803", # invalid-argument-name diff --git a/api/app_factory.py b/api/app_factory.py index 066eb2ae2c..76838f9925 100644 --- a/api/app_factory.py +++ b/api/app_factory.py @@ -143,6 +143,7 @@ def initialize_extensions(app: DifyApp): ext_commands, ext_compress, ext_database, + ext_enterprise_telemetry, ext_fastopenapi, ext_forward_refs, ext_hosting_provider, @@ -193,6 +194,7 @@ def initialize_extensions(app: DifyApp): ext_commands, ext_fastopenapi, ext_otel, + ext_enterprise_telemetry, ext_request_logging, ext_session_factory, ] diff --git a/api/configs/app_config.py b/api/configs/app_config.py index d3b1cf9d5b..831f0a49e0 100644 --- a/api/configs/app_config.py +++ b/api/configs/app_config.py @@ -8,7 +8,7 @@ from pydantic_settings import BaseSettings, PydanticBaseSettingsSource, Settings from libs.file_utils import search_file_upwards from .deploy import DeploymentConfig -from .enterprise import EnterpriseFeatureConfig +from .enterprise import EnterpriseFeatureConfig, EnterpriseTelemetryConfig from .extra import ExtraServiceConfig from .feature import FeatureConfig from .middleware import MiddlewareConfig @@ -73,6 +73,8 @@ class DifyConfig( # Enterprise feature configs # **Before using, please contact business@dify.ai by email to inquire about licensing matters.** EnterpriseFeatureConfig, + # Enterprise telemetry configs + EnterpriseTelemetryConfig, ): model_config = SettingsConfigDict( # read from dotenv format config file diff --git a/api/configs/enterprise/__init__.py 
b/api/configs/enterprise/__init__.py index f8447c6979..8a6a921a4e 100644 --- a/api/configs/enterprise/__init__.py +++ b/api/configs/enterprise/__init__.py @@ -22,3 +22,52 @@ class EnterpriseFeatureConfig(BaseSettings): ENTERPRISE_REQUEST_TIMEOUT: int = Field( ge=1, description="Maximum timeout in seconds for enterprise requests", default=5 ) + + +class EnterpriseTelemetryConfig(BaseSettings): + """ + Configuration for enterprise telemetry. + """ + + ENTERPRISE_TELEMETRY_ENABLED: bool = Field( + description="Enable enterprise telemetry collection (also requires ENTERPRISE_ENABLED=true).", + default=False, + ) + + ENTERPRISE_OTLP_ENDPOINT: str = Field( + description="Enterprise OTEL collector endpoint.", + default="", + ) + + ENTERPRISE_OTLP_HEADERS: str = Field( + description="Auth headers for OTLP export (key=value,key2=value2).", + default="", + ) + + ENTERPRISE_OTLP_PROTOCOL: str = Field( + description="OTLP protocol: 'http' or 'grpc' (default: http).", + default="http", + ) + + ENTERPRISE_OTLP_API_KEY: str = Field( + description="Bearer token for enterprise OTLP export authentication.", + default="", + ) + + ENTERPRISE_INCLUDE_CONTENT: bool = Field( + description="Include input/output content in traces (privacy toggle).", + # Setting the default value to False to avoid accidentally log PII data in traces. + default=False, + ) + + ENTERPRISE_SERVICE_NAME: str = Field( + description="Service name for OTEL resource.", + default="dify", + ) + + ENTERPRISE_OTEL_SAMPLING_RATE: float = Field( + description="Sampling rate for enterprise traces (0.0 to 1.0, default 1.0 = 100%).", + default=1.0, + ge=0.0, + le=1.0, + ) diff --git a/api/context/__init__.py b/api/context/__init__.py index 969e5f583d..8df37138e8 100644 --- a/api/context/__init__.py +++ b/api/context/__init__.py @@ -1,74 +1,36 @@ """ -Core Context - Framework-agnostic context management. +Application-layer context adapters. 
-This module provides context management that is independent of any specific -web framework. Framework-specific implementations register their context -capture functions at application initialization time. - -This ensures the workflow layer remains completely decoupled from Flask -or any other web framework. +Concrete execution-context implementations live here so `graphon` only +depends on injected context managers rather than framework state capture. """ -import contextvars -from collections.abc import Callable - -from dify_graph.context.execution_context import ( +from context.execution_context import ( + AppContext, + ContextProviderNotFoundError, ExecutionContext, + ExecutionContextBuilder, IExecutionContext, NullAppContext, + capture_current_context, + read_context, + register_context, + register_context_capturer, + reset_context_provider, ) - -# Global capturer function - set by framework-specific modules -_capturer: Callable[[], IExecutionContext] | None = None - - -def register_context_capturer(capturer: Callable[[], IExecutionContext]) -> None: - """ - Register a context capture function. - - This should be called by framework-specific modules (e.g., Flask) - during application initialization. - - Args: - capturer: Function that captures current context and returns IExecutionContext - """ - global _capturer - _capturer = capturer - - -def capture_current_context() -> IExecutionContext: - """ - Capture current execution context. - - This function uses the registered context capturer. If no capturer - is registered, it returns a minimal context with only contextvars - (suitable for non-framework environments like tests or standalone scripts). 
- - Returns: - IExecutionContext with captured context - """ - if _capturer is None: - # No framework registered - return minimal context - return ExecutionContext( - app_context=NullAppContext(), - context_vars=contextvars.copy_context(), - ) - - return _capturer() - - -def reset_context_provider() -> None: - """ - Reset the context capturer. - - This is primarily useful for testing to ensure a clean state. - """ - global _capturer - _capturer = None - +from context.models import SandboxContext __all__ = [ + "AppContext", + "ContextProviderNotFoundError", + "ExecutionContext", + "ExecutionContextBuilder", + "IExecutionContext", + "NullAppContext", + "SandboxContext", "capture_current_context", + "read_context", + "register_context", "register_context_capturer", "reset_context_provider", ] diff --git a/api/dify_graph/context/execution_context.py b/api/context/execution_context.py similarity index 60% rename from api/dify_graph/context/execution_context.py rename to api/context/execution_context.py index e3007530f0..ba9a24d4f3 100644 --- a/api/dify_graph/context/execution_context.py +++ b/api/context/execution_context.py @@ -1,5 +1,8 @@ """ -Execution Context - Abstracted context management for workflow execution. +Application-layer execution context adapters. + +Concrete context capture lives outside `graphon` so the graph package only +consumes injected context managers when it needs to preserve thread-local state. """ import contextvars @@ -16,33 +19,33 @@ class AppContext(ABC): """ Abstract application context interface. - This abstraction allows workflow execution to work with or without Flask - by providing a common interface for application context management. + Application adapters can implement this to restore framework-specific state + such as Flask app context around worker execution. 
""" @abstractmethod def get_config(self, key: str, default: Any = None) -> Any: """Get configuration value by key.""" - pass + raise NotImplementedError @abstractmethod def get_extension(self, name: str) -> Any: - """Get Flask extension by name (e.g., 'db', 'cache').""" - pass + """Get application extension by name.""" + raise NotImplementedError @abstractmethod def enter(self) -> AbstractContextManager[None]: """Enter the application context.""" - pass + raise NotImplementedError @runtime_checkable class IExecutionContext(Protocol): """ - Protocol for execution context. + Protocol for enterable execution context objects. - This protocol defines the interface that all execution contexts must implement, - allowing both ExecutionContext and FlaskExecutionContext to be used interchangeably. + Concrete implementations may carry extra framework state, but callers only + depend on standard context-manager behavior plus optional user metadata. """ def __enter__(self) -> "IExecutionContext": @@ -62,14 +65,10 @@ class IExecutionContext(Protocol): @final class ExecutionContext: """ - Execution context for workflow execution in worker threads. + Generic execution context used by application-layer adapters. - This class encapsulates all context needed for workflow execution: - - Application context (Flask app or standalone) - - Context variables for Python contextvars - - User information (optional) - - It is designed to be serializable and passable to worker threads. + It restores captured `contextvars` and optionally enters an application + context before the worker executes graph logic. """ def __init__( @@ -78,14 +77,6 @@ class ExecutionContext: context_vars: contextvars.Context | None = None, user: Any = None, ) -> None: - """ - Initialize execution context. 
- - Args: - app_context: Application context (Flask or standalone) - context_vars: Python contextvars to preserve - user: User object (optional) - """ self._app_context = app_context self._context_vars = context_vars self._user = user @@ -98,27 +89,21 @@ class ExecutionContext: @property def context_vars(self) -> contextvars.Context | None: - """Get context variables.""" + """Get captured context variables.""" return self._context_vars @property def user(self) -> Any: - """Get user object.""" + """Get captured user object.""" return self._user @contextmanager def enter(self) -> Generator[None, None, None]: - """ - Enter this execution context. - - This is a convenience method that creates a context manager. - """ - # Restore context variables if provided + """Enter this execution context.""" if self._context_vars: for var, val in self._context_vars.items(): var.set(val) - # Enter app context if available if self._app_context is not None: with self._app_context.enter(): yield @@ -141,18 +126,10 @@ class ExecutionContext: class NullAppContext(AppContext): """ - Null implementation of AppContext for non-Flask environments. - - This is used when running without Flask (e.g., in tests or standalone mode). + Null application context for non-framework environments. """ def __init__(self, config: dict[str, Any] | None = None) -> None: - """ - Initialize null app context. - - Args: - config: Optional configuration dictionary - """ self._config = config or {} self._extensions: dict[str, Any] = {} @@ -165,7 +142,7 @@ class NullAppContext(AppContext): return self._extensions.get(name) def set_extension(self, name: str, extension: Any) -> None: - """Set extension by name.""" + """Register an extension for tests or standalone execution.""" self._extensions[name] = extension @contextmanager @@ -176,9 +153,7 @@ class NullAppContext(AppContext): class ExecutionContextBuilder: """ - Builder for creating ExecutionContext instances. 
- - This provides a fluent API for building execution contexts. + Builder for creating `ExecutionContext` instances. """ def __init__(self) -> None: @@ -211,63 +186,42 @@ class ExecutionContextBuilder: _capturer: Callable[[], IExecutionContext] | None = None - -# Tenant-scoped providers using tuple keys for clarity and constant-time lookup. -# Key mapping: -# (name, tenant_id) -> provider -# - name: namespaced identifier (recommend prefixing, e.g. "workflow.sandbox") -# - tenant_id: tenant identifier string -# Value: -# provider: Callable[[], BaseModel] returning the typed context value -# Type-safety note: -# - This registry cannot enforce that all providers for a given name return the same BaseModel type. -# - Implementors SHOULD provide typed wrappers around register/read (like Go's context best practice), -# e.g. def register_sandbox_ctx(tenant_id: str, p: Callable[[], SandboxContext]) and -# def read_sandbox_ctx(tenant_id: str) -> SandboxContext. _tenant_context_providers: dict[tuple[str, str], Callable[[], BaseModel]] = {} T = TypeVar("T", bound=BaseModel) class ContextProviderNotFoundError(KeyError): - """Raised when a tenant-scoped context provider is missing for a given (name, tenant_id).""" + """Raised when a tenant-scoped context provider is missing.""" pass def register_context_capturer(capturer: Callable[[], IExecutionContext]) -> None: - """Register a single enterable execution context capturer (e.g., Flask).""" + """Register an enterable execution context capturer.""" global _capturer _capturer = capturer def register_context(name: str, tenant_id: str, provider: Callable[[], BaseModel]) -> None: - """Register a tenant-specific provider for a named context. - - Tip: use a namespaced "name" (e.g., "workflow.sandbox") to avoid key collisions. - Consider adding a typed wrapper for this registration in your feature module. 
- """ + """Register a tenant-specific provider for a named context.""" _tenant_context_providers[(name, tenant_id)] = provider def read_context(name: str, *, tenant_id: str) -> BaseModel: - """ - Read a context value for a specific tenant. - - Raises KeyError if the provider for (name, tenant_id) is not registered. - """ - prov = _tenant_context_providers.get((name, tenant_id)) - if prov is None: + """Read a context value for a specific tenant.""" + provider = _tenant_context_providers.get((name, tenant_id)) + if provider is None: raise ContextProviderNotFoundError(f"Context provider '{name}' not registered for tenant '{tenant_id}'") - return prov() + return provider() def capture_current_context() -> IExecutionContext: """ Capture current execution context from the calling environment. - If a capturer is registered (e.g., Flask), use it. Otherwise, return a minimal - context with NullAppContext + copy of current contextvars. + If no framework adapter is registered, return a minimal context that only + restores `contextvars`. 
""" if _capturer is None: return ExecutionContext( @@ -278,7 +232,22 @@ def capture_current_context() -> IExecutionContext: def reset_context_provider() -> None: - """Reset the capturer and all tenant-scoped context providers (primarily for tests).""" + """Reset the capturer and tenant-scoped providers.""" global _capturer _capturer = None _tenant_context_providers.clear() + + +__all__ = [ + "AppContext", + "ContextProviderNotFoundError", + "ExecutionContext", + "ExecutionContextBuilder", + "IExecutionContext", + "NullAppContext", + "capture_current_context", + "read_context", + "register_context", + "register_context_capturer", + "reset_context_provider", +] diff --git a/api/context/flask_app_context.py b/api/context/flask_app_context.py index 324a9ee8b4..eddd6448d8 100644 --- a/api/context/flask_app_context.py +++ b/api/context/flask_app_context.py @@ -10,11 +10,7 @@ from typing import Any, final from flask import Flask, current_app, g -from dify_graph.context import register_context_capturer -from dify_graph.context.execution_context import ( - AppContext, - IExecutionContext, -) +from context.execution_context import AppContext, IExecutionContext, register_context_capturer @final diff --git a/api/dify_graph/context/models.py b/api/context/models.py similarity index 100% rename from api/dify_graph/context/models.py rename to api/context/models.py diff --git a/api/contexts/__init__.py b/api/contexts/__init__.py index c52dcf8a57..764f9f8ee2 100644 --- a/api/contexts/__init__.py +++ b/api/contexts/__init__.py @@ -6,7 +6,6 @@ from contexts.wrapper import RecyclableContextVar if TYPE_CHECKING: from core.datasource.__base.datasource_provider import DatasourcePluginProviderController - from core.plugin.entities.plugin_daemon import PluginModelProviderEntity from core.tools.plugin_tool.provider import PluginToolProviderController from core.trigger.provider import PluginTriggerProviderController @@ -20,14 +19,6 @@ plugin_tool_providers: RecyclableContextVar[dict[str, 
"PluginToolProviderControl plugin_tool_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar(ContextVar("plugin_tool_providers_lock")) -plugin_model_providers: RecyclableContextVar[list["PluginModelProviderEntity"] | None] = RecyclableContextVar( - ContextVar("plugin_model_providers") -) - -plugin_model_providers_lock: RecyclableContextVar[Lock] = RecyclableContextVar( - ContextVar("plugin_model_providers_lock") -) - datasource_plugin_providers: RecyclableContextVar[dict[str, "DatasourcePluginProviderController"]] = ( RecyclableContextVar(ContextVar("datasource_plugin_providers")) ) diff --git a/api/controllers/common/fields.py b/api/controllers/common/fields.py index ff5326dade..7348ef62aa 100644 --- a/api/controllers/common/fields.py +++ b/api/controllers/common/fields.py @@ -2,9 +2,9 @@ from __future__ import annotations from typing import Any, TypeAlias +from graphon.file import helpers as file_helpers from pydantic import BaseModel, ConfigDict, computed_field -from dify_graph.file import helpers as file_helpers from models.model import IconType JSONValue: TypeAlias = str | int | float | bool | None | dict[str, Any] | list[Any] diff --git a/api/controllers/console/app/app.py b/api/controllers/console/app/app.py index 7e41260eeb..738e77b371 100644 --- a/api/controllers/console/app/app.py +++ b/api/controllers/console/app/app.py @@ -5,6 +5,8 @@ from typing import Any, Literal, TypeAlias from flask import request from flask_restx import Resource +from graphon.enums import WorkflowExecutionStatus +from graphon.file import helpers as file_helpers from pydantic import AliasChoices, BaseModel, ConfigDict, Field, computed_field, field_validator from sqlalchemy import select from sqlalchemy.orm import Session @@ -26,8 +28,6 @@ from controllers.console.wraps import ( from core.ops.ops_trace_manager import OpsTraceManager from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.trigger.constants import TRIGGER_NODE_TYPES -from 
dify_graph.enums import WorkflowExecutionStatus -from dify_graph.file import helpers as file_helpers from extensions.ext_database import db from libs.login import current_account_with_tenant, login_required from models import App, DatasetPermissionEnum, Workflow diff --git a/api/controllers/console/app/audio.py b/api/controllers/console/app/audio.py index 2c5e8d29ee..78ddb904e1 100644 --- a/api/controllers/console/app/audio.py +++ b/api/controllers/console/app/audio.py @@ -2,6 +2,7 @@ import logging from flask import request from flask_restx import Resource, fields +from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel, Field from werkzeug.exceptions import InternalServerError @@ -22,7 +23,6 @@ from controllers.console.app.error import ( from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, setup_required from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from dify_graph.model_runtime.errors.invoke import InvokeError from libs.login import login_required from models import App, AppMode from services.audio_service import AudioService diff --git a/api/controllers/console/app/completion.py b/api/controllers/console/app/completion.py index 4d7ddfea13..d83925d173 100644 --- a/api/controllers/console/app/completion.py +++ b/api/controllers/console/app/completion.py @@ -3,6 +3,7 @@ from typing import Any, Literal from flask import request from flask_restx import Resource +from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel, Field, field_validator from werkzeug.exceptions import InternalServerError, NotFound @@ -26,7 +27,6 @@ from core.errors.error import ( QuotaExceededError, ) from core.helper.trace_id_helper import get_external_trace_id -from dify_graph.model_runtime.errors.invoke import InvokeError from libs import helper from libs.helper import uuid_value from 
libs.login import current_user, login_required diff --git a/api/controllers/console/app/generator.py b/api/controllers/console/app/generator.py index 442d0d2324..7101d5df7b 100644 --- a/api/controllers/console/app/generator.py +++ b/api/controllers/console/app/generator.py @@ -1,6 +1,7 @@ from collections.abc import Sequence from flask_restx import Resource +from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel, Field from controllers.console import console_ns @@ -18,7 +19,6 @@ from core.helper.code_executor.javascript.javascript_code_provider import Javasc from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider from core.llm_generator.entities import RuleCodeGeneratePayload, RuleGeneratePayload, RuleStructuredOutputPayload from core.llm_generator.llm_generator import LLMGenerator -from dify_graph.model_runtime.errors.invoke import InvokeError from extensions.ext_database import db from libs.login import current_account_with_tenant, login_required from models import App diff --git a/api/controllers/console/app/message.py b/api/controllers/console/app/message.py index 736e7dbe17..2afe276742 100644 --- a/api/controllers/console/app/message.py +++ b/api/controllers/console/app/message.py @@ -3,6 +3,7 @@ from typing import Literal from flask import request from flask_restx import Resource, fields, marshal_with +from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel, Field, field_validator from sqlalchemy import exists, func, select from werkzeug.exceptions import InternalServerError, NotFound @@ -24,7 +25,6 @@ from controllers.console.wraps import ( ) from core.app.entities.app_invoke_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from dify_graph.model_runtime.errors.invoke import InvokeError from extensions.ext_database import db from fields.raws import FilesContainedField from 
libs.helper import TimestampField, uuid_value diff --git a/api/controllers/console/app/model_config.py b/api/controllers/console/app/model_config.py index e9bd30ba7e..8bb5aa2c1b 100644 --- a/api/controllers/console/app/model_config.py +++ b/api/controllers/console/app/model_config.py @@ -88,6 +88,7 @@ class ModelConfigResource(Resource): tenant_id=current_tenant_id, app_id=app_model.id, agent_tool=agent_tool_entity, + user_id=current_user.id, ) manager = ToolParameterConfigurationManager( tenant_id=current_tenant_id, @@ -127,6 +128,7 @@ class ModelConfigResource(Resource): tenant_id=current_tenant_id, app_id=app_model.id, agent_tool=agent_tool_entity, + user_id=current_user.id, ) except Exception: continue diff --git a/api/controllers/console/app/workflow.py b/api/controllers/console/app/workflow.py index d59aa44718..1f5a84c0b2 100644 --- a/api/controllers/console/app/workflow.py +++ b/api/controllers/console/app/workflow.py @@ -5,6 +5,10 @@ from typing import Any from flask import abort, request from flask_restx import Resource, fields, marshal_with +from graphon.enums import NodeType +from graphon.file import File +from graphon.graph_engine.manager import GraphEngineManager +from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel, Field, field_validator from sqlalchemy.orm import Session from werkzeug.exceptions import BadRequest, Forbidden, InternalServerError, NotFound @@ -20,6 +24,7 @@ from core.app.app_config.features.file_upload.manager import FileUploadConfigMan from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.workflow.app_generator import SKIP_PREPARE_USER_INPUTS_KEY from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.file_access import DatabaseFileAccessController from core.helper.trace_id_helper import get_external_trace_id from core.plugin.impl.exc import PluginInvokeError from core.trigger.constants import TRIGGER_SCHEDULE_NODE_TYPE @@ -29,10 +34,6 @@ 
from core.trigger.debug.event_selectors import ( create_event_poller, select_trigger_debug_events, ) -from dify_graph.enums import NodeType -from dify_graph.file.models import File -from dify_graph.graph_engine.manager import GraphEngineManager -from dify_graph.model_runtime.utils.encoders import jsonable_encoder from extensions.ext_database import db from extensions.ext_redis import redis_client from factories import file_factory, variable_factory @@ -51,6 +52,7 @@ from services.errors.llm import InvokeRateLimitError from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError, WorkflowService logger = logging.getLogger(__name__) +_file_access_controller = DatabaseFileAccessController() LISTENING_RETRY_IN = 2000 DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" RESTORE_SOURCE_WORKFLOW_MUST_BE_PUBLISHED_MESSAGE = "source workflow must be published" @@ -204,6 +206,7 @@ def _parse_file(workflow: Workflow, files: list[dict] | None = None) -> Sequence mappings=files, tenant_id=workflow.tenant_id, config=file_extra_config, + access_controller=_file_access_controller, ) return file_objs diff --git a/api/controllers/console/app/workflow_app_log.py b/api/controllers/console/app/workflow_app_log.py index 9b148c3f18..f0e26c86a5 100644 --- a/api/controllers/console/app/workflow_app_log.py +++ b/api/controllers/console/app/workflow_app_log.py @@ -3,13 +3,13 @@ from datetime import datetime from dateutil.parser import isoparse from flask import request from flask_restx import Resource, marshal_with +from graphon.enums import WorkflowExecutionStatus from pydantic import BaseModel, Field, field_validator from sqlalchemy.orm import Session from controllers.console import console_ns from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, setup_required -from dify_graph.enums import WorkflowExecutionStatus from extensions.ext_database import db from fields.workflow_app_log_fields 
import ( build_workflow_app_log_pagination_model, diff --git a/api/controllers/console/app/workflow_draft_variable.py b/api/controllers/console/app/workflow_draft_variable.py index b78d97a382..4052897e9a 100644 --- a/api/controllers/console/app/workflow_draft_variable.py +++ b/api/controllers/console/app/workflow_draft_variable.py @@ -5,6 +5,10 @@ from typing import Any, NoReturn, ParamSpec, TypeVar from flask import Response, request from flask_restx import Resource, fields, marshal, marshal_with +from graphon.file import helpers as file_helpers +from graphon.variables.segment_group import SegmentGroup +from graphon.variables.segments import ArrayFileSegment, FileSegment, Segment +from graphon.variables.types import SegmentType from pydantic import BaseModel, Field from sqlalchemy.orm import Session @@ -15,11 +19,8 @@ from controllers.console.app.error import ( from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required from controllers.web.error import InvalidArgumentError, NotFoundError -from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID -from dify_graph.file import helpers as file_helpers -from dify_graph.variables.segment_group import SegmentGroup -from dify_graph.variables.segments import ArrayFileSegment, FileSegment, Segment -from dify_graph.variables.types import SegmentType +from core.app.file_access import DatabaseFileAccessController +from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID from extensions.ext_database import db from factories.file_factory import build_from_mapping, build_from_mappings from factories.variable_factory import build_segment_with_type @@ -30,6 +31,7 @@ from services.workflow_draft_variable_service import WorkflowDraftVariableList, from services.workflow_service import WorkflowService logger = logging.getLogger(__name__) 
+_file_access_controller = DatabaseFileAccessController() DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}" @@ -389,13 +391,21 @@ class VariableApi(Resource): if variable.value_type == SegmentType.FILE: if not isinstance(raw_value, dict): raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}") - raw_value = build_from_mapping(mapping=raw_value, tenant_id=app_model.tenant_id) + raw_value = build_from_mapping( + mapping=raw_value, + tenant_id=app_model.tenant_id, + access_controller=_file_access_controller, + ) elif variable.value_type == SegmentType.ARRAY_FILE: if not isinstance(raw_value, list): raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}") if len(raw_value) > 0 and not isinstance(raw_value[0], dict): raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}") - raw_value = build_from_mappings(mappings=raw_value, tenant_id=app_model.tenant_id) + raw_value = build_from_mappings( + mappings=raw_value, + tenant_id=app_model.tenant_id, + access_controller=_file_access_controller, + ) new_value = build_segment_with_type(variable.value_type, raw_value) draft_var_srv.update_variable(variable, name=new_name, value=new_value) db.session.commit() diff --git a/api/controllers/console/app/workflow_run.py b/api/controllers/console/app/workflow_run.py index 7ac653395e..83e8bedc11 100644 --- a/api/controllers/console/app/workflow_run.py +++ b/api/controllers/console/app/workflow_run.py @@ -1,8 +1,10 @@ from datetime import UTC, datetime, timedelta -from typing import Literal, cast +from typing import Literal, TypedDict, cast from flask import request from flask_restx import Resource, fields, marshal_with +from graphon.entities.pause_reason import HumanInputRequired +from graphon.enums import WorkflowExecutionStatus from pydantic import BaseModel, Field, field_validator from sqlalchemy import select from sqlalchemy.orm import sessionmaker @@ -12,8 +14,7 @@ from 
controllers.console import console_ns from controllers.console.app.wraps import get_app_model from controllers.console.wraps import account_initialization_required, setup_required from controllers.web.error import NotFoundError -from dify_graph.entities.pause_reason import HumanInputRequired -from dify_graph.enums import WorkflowExecutionStatus +from core.workflow.human_input_forms import load_form_tokens_by_form_id as _load_form_tokens_by_form_id from extensions.ext_database import db from fields.end_user_fields import simple_end_user_fields from fields.member_fields import simple_account_fields @@ -172,6 +173,23 @@ console_ns.schema_model( ) +class HumanInputPauseTypeResponse(TypedDict): + type: Literal["human_input"] + form_id: str + backstage_input_url: str | None + + +class PausedNodeResponse(TypedDict): + node_id: str + node_title: str + pause_type: HumanInputPauseTypeResponse + + +class WorkflowPauseDetailsResponse(TypedDict): + paused_at: str | None + paused_nodes: list[PausedNodeResponse] + + @console_ns.route("/apps//advanced-chat/workflow-runs") class AdvancedChatAppWorkflowRunListApi(Resource): @console_ns.doc("get_advanced_chat_workflow_runs") @@ -489,18 +507,22 @@ class ConsoleWorkflowPauseDetailsApi(Resource): # Check if workflow is suspended is_paused = workflow_run.status == WorkflowExecutionStatus.PAUSED if not is_paused: - return { + empty_response: WorkflowPauseDetailsResponse = { "paused_at": None, "paused_nodes": [], - }, 200 + } + return empty_response, 200 pause_entity = workflow_run_repo.get_workflow_pause(workflow_run_id) pause_reasons = pause_entity.get_pause_reasons() if pause_entity else [] + form_tokens_by_form_id = _load_form_tokens_by_form_id( + [reason.form_id for reason in pause_reasons if isinstance(reason, HumanInputRequired)] + ) # Build response paused_at = pause_entity.paused_at if pause_entity else None - paused_nodes = [] - response = { + paused_nodes: list[PausedNodeResponse] = [] + response: WorkflowPauseDetailsResponse = 
{ "paused_at": paused_at.isoformat() + "Z" if paused_at else None, "paused_nodes": paused_nodes, } @@ -514,7 +536,9 @@ class ConsoleWorkflowPauseDetailsApi(Resource): "pause_type": { "type": "human_input", "form_id": reason.form_id, - "backstage_input_url": _build_backstage_input_url(reason.form_token), + "backstage_input_url": _build_backstage_input_url( + form_tokens_by_form_id.get(reason.form_id) + ), }, } ) diff --git a/api/controllers/console/auth/oauth_server.py b/api/controllers/console/auth/oauth_server.py index 6e59d4203c..686b865871 100644 --- a/api/controllers/console/auth/oauth_server.py +++ b/api/controllers/console/auth/oauth_server.py @@ -4,11 +4,11 @@ from typing import Concatenate, ParamSpec, TypeVar from flask import jsonify, request from flask_restx import Resource +from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel from werkzeug.exceptions import BadRequest, NotFound from controllers.console.wraps import account_initialization_required, setup_required -from dify_graph.model_runtime.utils.encoders import jsonable_encoder from libs.login import current_account_with_tenant, login_required from models import Account from models.model import OAuthProviderApp diff --git a/api/controllers/console/datasets/datasets.py b/api/controllers/console/datasets/datasets.py index 27c772fbe0..f23c7eb431 100644 --- a/api/controllers/console/datasets/datasets.py +++ b/api/controllers/console/datasets/datasets.py @@ -2,6 +2,7 @@ from typing import Any, cast from flask import request from flask_restx import Resource, fields, marshal, marshal_with +from graphon.model_runtime.entities.model_entities import ModelType from pydantic import BaseModel, Field, field_validator from sqlalchemy import func, select from werkzeug.exceptions import Forbidden, NotFound @@ -25,13 +26,12 @@ from controllers.console.wraps import ( ) from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.indexing_runner import 
IndexingRunner -from core.provider_manager import ProviderManager +from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager from core.rag.datasource.vdb.vector_type import VectorType from core.rag.extractor.entity.datasource_type import DatasourceType from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.retrieval.retrieval_methods import RetrievalMethod -from dify_graph.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db from fields.app_fields import app_detail_kernel_fields, related_app_list from fields.dataset_fields import ( @@ -332,7 +332,7 @@ class DatasetListApi(Resource): ) # check embedding setting - provider_manager = ProviderManager() + provider_manager = create_plugin_provider_manager(tenant_id=current_tenant_id) configurations = provider_manager.get_configurations(tenant_id=current_tenant_id) embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True) @@ -446,7 +446,7 @@ class DatasetApi(Resource): data.update({"partial_member_list": part_users_list}) # check embedding setting - provider_manager = ProviderManager() + provider_manager = create_plugin_provider_manager(tenant_id=current_tenant_id) configurations = provider_manager.get_configurations(tenant_id=current_tenant_id) embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True) diff --git a/api/controllers/console/datasets/datasets_document.py b/api/controllers/console/datasets/datasets_document.py index 897724182f..ab367d8483 100644 --- a/api/controllers/console/datasets/datasets_document.py +++ b/api/controllers/console/datasets/datasets_document.py @@ -9,6 +9,8 @@ from uuid import UUID import sqlalchemy as sa from flask import request, send_file from flask_restx import Resource, fields, marshal, marshal_with +from 
graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from pydantic import BaseModel, Field from sqlalchemy import asc, desc, func, select from werkzeug.exceptions import Forbidden, NotFound @@ -28,8 +30,6 @@ from core.plugin.impl.exc import PluginDaemonClientSideError from core.rag.extractor.entity.datasource_type import DatasourceType from core.rag.extractor.entity.extract_setting import ExtractSetting, NotionInfo, WebsiteInfo from core.rag.index_processor.constant.index_type import IndexTechniqueType -from dify_graph.model_runtime.entities.model_entities import ModelType -from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError from extensions.ext_database import db from fields.dataset_fields import dataset_fields from fields.document_fields import ( @@ -454,7 +454,7 @@ class DatasetInitApi(Resource): if knowledge_config.embedding_model is None or knowledge_config.embedding_model_provider is None: raise ValueError("embedding model and embedding model provider are required for high quality indexing.") try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=current_tenant_id) model_manager.get_model_instance( tenant_id=current_tenant_id, provider=knowledge_config.embedding_model_provider, diff --git a/api/controllers/console/datasets/datasets_segments.py b/api/controllers/console/datasets/datasets_segments.py index 7333fcaa07..c5f4e3a6e2 100644 --- a/api/controllers/console/datasets/datasets_segments.py +++ b/api/controllers/console/datasets/datasets_segments.py @@ -2,6 +2,7 @@ import uuid from flask import request from flask_restx import Resource, marshal +from graphon.model_runtime.entities.model_entities import ModelType from pydantic import BaseModel, Field from sqlalchemy import String, cast, func, or_, select from sqlalchemy.dialects.postgresql import JSONB @@ -27,7 +28,6 @@ from controllers.console.wraps import ( from 
core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.model_manager import ModelManager from core.rag.index_processor.constant.index_type import IndexTechniqueType -from dify_graph.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db from extensions.ext_redis import redis_client from fields.segment_fields import child_chunk_fields, segment_fields @@ -283,7 +283,7 @@ class DatasetDocumentSegmentApi(Resource): if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: # check embedding model setting try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=current_tenant_id) model_manager.get_model_instance( tenant_id=current_tenant_id, provider=dataset.embedding_model_provider, @@ -336,7 +336,7 @@ class DatasetDocumentSegmentAddApi(Resource): # check embedding model setting if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=current_tenant_id) model_manager.get_model_instance( tenant_id=current_tenant_id, provider=dataset.embedding_model_provider, @@ -387,7 +387,7 @@ class DatasetDocumentSegmentUpdateApi(Resource): if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: # check embedding model setting try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=current_tenant_id) model_manager.get_model_instance( tenant_id=current_tenant_id, provider=dataset.embedding_model_provider, @@ -572,7 +572,7 @@ class ChildChunkAddApi(Resource): # check embedding model setting if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=current_tenant_id) model_manager.get_model_instance( tenant_id=current_tenant_id, provider=dataset.embedding_model_provider, diff --git a/api/controllers/console/datasets/external.py 
b/api/controllers/console/datasets/external.py index 86090bcd10..fc6896f123 100644 --- a/api/controllers/console/datasets/external.py +++ b/api/controllers/console/datasets/external.py @@ -25,7 +25,7 @@ from libs.login import current_account_with_tenant, login_required from services.dataset_service import DatasetService from services.external_knowledge_service import ExternalDatasetService from services.hit_testing_service import HitTestingService -from services.knowledge_service import ExternalDatasetTestService +from services.knowledge_service import BedrockRetrievalSetting, ExternalDatasetTestService def _build_dataset_detail_model(): @@ -86,7 +86,7 @@ class ExternalHitTestingPayload(BaseModel): class BedrockRetrievalPayload(BaseModel): - retrieval_setting: dict[str, object] + retrieval_setting: "BedrockRetrievalSetting" query: str knowledge_id: str diff --git a/api/controllers/console/datasets/hit_testing_base.py b/api/controllers/console/datasets/hit_testing_base.py index cd568cf835..8fb3699849 100644 --- a/api/controllers/console/datasets/hit_testing_base.py +++ b/api/controllers/console/datasets/hit_testing_base.py @@ -2,6 +2,7 @@ import logging from typing import Any from flask_restx import marshal +from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel, Field from werkzeug.exceptions import Forbidden, InternalServerError, NotFound @@ -19,7 +20,6 @@ from core.errors.error import ( ProviderTokenNotInitError, QuotaExceededError, ) -from dify_graph.model_runtime.errors.invoke import InvokeError from fields.hit_testing_fields import hit_testing_record_fields from libs.login import current_user from models.account import Account diff --git a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py index a4498005d8..1976a6bc8a 100644 --- a/api/controllers/console/datasets/rag_pipeline/datasource_auth.py +++ 
b/api/controllers/console/datasets/rag_pipeline/datasource_auth.py @@ -2,6 +2,8 @@ from typing import Any from flask import make_response, redirect, request from flask_restx import Resource +from graphon.model_runtime.errors.validate import CredentialsValidateFailedError +from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel, Field from werkzeug.exceptions import Forbidden, NotFound @@ -10,8 +12,6 @@ from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, edit_permission_required, setup_required from core.plugin.impl.oauth import OAuthHandler -from dify_graph.model_runtime.errors.validate import CredentialsValidateFailedError -from dify_graph.model_runtime.utils.encoders import jsonable_encoder from libs.login import current_account_with_tenant, login_required from models.provider_ids import DatasourceProviderID from services.datasource_provider_service import DatasourceProviderService diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py index c5dadb75f5..f12cbd3495 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_draft_variable.py @@ -3,6 +3,7 @@ from typing import Any, NoReturn from flask import Response, request from flask_restx import Resource, marshal, marshal_with +from graphon.variables.types import SegmentType from pydantic import BaseModel, Field from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden @@ -21,8 +22,8 @@ from controllers.console.app.workflow_draft_variable import ( from controllers.console.datasets.wraps import get_rag_pipeline from controllers.console.wraps import account_initialization_required, setup_required from controllers.web.error 
import InvalidArgumentError, NotFoundError -from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID -from dify_graph.variables.types import SegmentType +from core.app.file_access import DatabaseFileAccessController +from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID from extensions.ext_database import db from factories.file_factory import build_from_mapping, build_from_mappings from factories.variable_factory import build_segment_with_type @@ -33,6 +34,7 @@ from services.rag_pipeline.rag_pipeline import RagPipelineService from services.workflow_draft_variable_service import WorkflowDraftVariableList, WorkflowDraftVariableService logger = logging.getLogger(__name__) +_file_access_controller = DatabaseFileAccessController() def _create_pagination_parser(): @@ -223,13 +225,21 @@ class RagPipelineVariableApi(Resource): if variable.value_type == SegmentType.FILE: if not isinstance(raw_value, dict): raise InvalidArgumentError(description=f"expected dict for file, got {type(raw_value)}") - raw_value = build_from_mapping(mapping=raw_value, tenant_id=pipeline.tenant_id) + raw_value = build_from_mapping( + mapping=raw_value, + tenant_id=pipeline.tenant_id, + access_controller=_file_access_controller, + ) elif variable.value_type == SegmentType.ARRAY_FILE: if not isinstance(raw_value, list): raise InvalidArgumentError(description=f"expected list for files, got {type(raw_value)}") if len(raw_value) > 0 and not isinstance(raw_value[0], dict): raise InvalidArgumentError(description=f"expected dict for files[0], got {type(raw_value)}") - raw_value = build_from_mappings(mappings=raw_value, tenant_id=pipeline.tenant_id) + raw_value = build_from_mappings( + mappings=raw_value, + tenant_id=pipeline.tenant_id, + access_controller=_file_access_controller, + ) new_value = build_segment_with_type(variable.value_type, raw_value) draft_var_srv.update_variable(variable, name=new_name, value=new_value) 
db.session.commit() diff --git a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py index 3912cc73ca..8efb59a8e9 100644 --- a/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py +++ b/api/controllers/console/datasets/rag_pipeline/rag_pipeline_workflow.py @@ -4,6 +4,7 @@ from typing import Any, Literal, cast from flask import abort, request from flask_restx import Resource, marshal_with # type: ignore +from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel, Field from sqlalchemy.orm import Session from werkzeug.exceptions import BadRequest, Forbidden, InternalServerError, NotFound @@ -37,7 +38,6 @@ from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpErr from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.pipeline.pipeline_generator import PipelineGenerator from core.app.entities.app_invoke_entities import InvokeFrom -from dify_graph.model_runtime.utils.encoders import jsonable_encoder from extensions.ext_database import db from factories import variable_factory from libs import helper @@ -53,6 +53,7 @@ from services.rag_pipeline.pipeline_generate_service import PipelineGenerateServ from services.rag_pipeline.rag_pipeline import RagPipelineService from services.rag_pipeline.rag_pipeline_manage_service import RagPipelineManageService from services.rag_pipeline.rag_pipeline_transform_service import RagPipelineTransformService +from services.workflow_service import DraftWorkflowDeletionError, WorkflowInUseError, WorkflowService logger = logging.getLogger(__name__) @@ -781,7 +782,38 @@ class RagPipelineByIdApi(Resource): # Commit the transaction in the controller session.commit() - return workflow + return workflow + + @setup_required + @login_required + @account_initialization_required + @edit_permission_required + @get_rag_pipeline + def delete(self, pipeline: 
Pipeline, workflow_id: str): + """ + Delete a published workflow version that is not currently active on the pipeline. + """ + if pipeline.workflow_id == workflow_id: + abort(400, description=f"Cannot delete workflow that is currently in use by pipeline '{pipeline.id}'") + + workflow_service = WorkflowService() + + with Session(db.engine) as session: + try: + workflow_service.delete_workflow( + session=session, + workflow_id=workflow_id, + tenant_id=pipeline.tenant_id, + ) + session.commit() + except WorkflowInUseError as e: + abort(400, description=str(e)) + except DraftWorkflowDeletionError as e: + abort(400, description=str(e)) + except ValueError as e: + raise NotFound(str(e)) + + return None, 204 @console_ns.route("/rag/pipelines//workflows/published/processing/parameters") diff --git a/api/controllers/console/explore/audio.py b/api/controllers/console/explore/audio.py index ffb9e5bb6e..b1b01b5f51 100644 --- a/api/controllers/console/explore/audio.py +++ b/api/controllers/console/explore/audio.py @@ -1,6 +1,7 @@ import logging from flask import request +from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel, Field from werkzeug.exceptions import InternalServerError @@ -19,7 +20,6 @@ from controllers.console.app.error import ( ) from controllers.console.explore.wraps import InstalledAppResource from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from dify_graph.model_runtime.errors.invoke import InvokeError from services.audio_service import AudioService from services.errors.audio import ( AudioTooLargeServiceError, diff --git a/api/controllers/console/explore/completion.py b/api/controllers/console/explore/completion.py index fcd52d2818..eacd7332fe 100644 --- a/api/controllers/console/explore/completion.py +++ b/api/controllers/console/explore/completion.py @@ -2,6 +2,7 @@ import logging from typing import Any, Literal from uuid import UUID +from 
graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel, Field, field_validator from werkzeug.exceptions import InternalServerError, NotFound @@ -24,7 +25,6 @@ from core.errors.error import ( ProviderTokenNotInitError, QuotaExceededError, ) -from dify_graph.model_runtime.errors.invoke import InvokeError from extensions.ext_database import db from libs import helper from libs.datetime_utils import naive_utc_now diff --git a/api/controllers/console/explore/message.py b/api/controllers/console/explore/message.py index 15e1aea361..fcbefcda33 100644 --- a/api/controllers/console/explore/message.py +++ b/api/controllers/console/explore/message.py @@ -2,6 +2,7 @@ import logging from typing import Literal from flask import request +from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel, Field, TypeAdapter from werkzeug.exceptions import InternalServerError, NotFound @@ -21,7 +22,6 @@ from controllers.console.explore.error import ( from controllers.console.explore.wraps import InstalledAppResource from core.app.entities.app_invoke_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from dify_graph.model_runtime.errors.invoke import InvokeError from fields.conversation_fields import ResultResponse from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem, SuggestedQuestionsResponse from libs import helper diff --git a/api/controllers/console/explore/trial.py b/api/controllers/console/explore/trial.py index a8d8036f0f..e432574434 100644 --- a/api/controllers/console/explore/trial.py +++ b/api/controllers/console/explore/trial.py @@ -3,6 +3,8 @@ from typing import Any, Literal, cast from flask import request from flask_restx import Resource, fields, marshal, marshal_with +from graphon.graph_engine.manager import GraphEngineManager +from graphon.model_runtime.errors.invoke import InvokeError from pydantic 
import BaseModel from sqlalchemy import select from werkzeug.exceptions import Forbidden, InternalServerError, NotFound @@ -42,8 +44,6 @@ from core.errors.error import ( ProviderTokenNotInitError, QuotaExceededError, ) -from dify_graph.graph_engine.manager import GraphEngineManager -from dify_graph.model_runtime.errors.invoke import InvokeError from extensions.ext_database import db from extensions.ext_redis import redis_client from fields.app_fields import ( diff --git a/api/controllers/console/explore/workflow.py b/api/controllers/console/explore/workflow.py index 7801cee473..42cafc7193 100644 --- a/api/controllers/console/explore/workflow.py +++ b/api/controllers/console/explore/workflow.py @@ -1,6 +1,8 @@ import logging from typing import Any +from graphon.graph_engine.manager import GraphEngineManager +from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel from werkzeug.exceptions import InternalServerError @@ -21,8 +23,6 @@ from core.errors.error import ( ProviderTokenNotInitError, QuotaExceededError, ) -from dify_graph.graph_engine.manager import GraphEngineManager -from dify_graph.model_runtime.errors.invoke import InvokeError from extensions.ext_redis import redis_client from libs import helper from libs.login import current_account_with_tenant diff --git a/api/controllers/console/human_input_form.py b/api/controllers/console/human_input_form.py index 7207f7fd1d..e37e78c966 100644 --- a/api/controllers/console/human_input_form.py +++ b/api/controllers/console/human_input_form.py @@ -15,6 +15,7 @@ from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, setup_required from controllers.web.error import InvalidArgumentError, NotFoundError from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator +from core.app.apps.base_app_generator import BaseAppGenerator from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter from 
core.app.apps.message_generator import MessageGenerator from core.app.apps.workflow.app_generator import WorkflowAppGenerator @@ -166,6 +167,7 @@ class ConsoleWorkflowEventsApi(Resource): else: msg_generator = MessageGenerator() + generator: BaseAppGenerator if app.mode == AppMode.ADVANCED_CHAT: generator = AdvancedChatAppGenerator() elif app.mode == AppMode.WORKFLOW: @@ -202,7 +204,7 @@ class ConsoleWorkflowEventsApi(Resource): ) -def _retrieve_app_for_workflow_run(session: Session, workflow_run: WorkflowRun): +def _retrieve_app_for_workflow_run(session: Session, workflow_run: WorkflowRun) -> App: query = select(App).where( App.id == workflow_run.app_id, App.tenant_id == workflow_run.tenant_id, diff --git a/api/controllers/console/remote_files.py b/api/controllers/console/remote_files.py index 49162d4dae..551c86fd82 100644 --- a/api/controllers/console/remote_files.py +++ b/api/controllers/console/remote_files.py @@ -2,6 +2,7 @@ import urllib.parse import httpx from flask_restx import Resource +from graphon.file import helpers as file_helpers from pydantic import BaseModel, Field import services @@ -13,7 +14,6 @@ from controllers.common.errors import ( ) from controllers.console import console_ns from core.helper import ssrf_proxy -from dify_graph.file import helpers as file_helpers from extensions.ext_database import db from fields.file_fields import FileWithSignedUrl, RemoteFileInfo from libs.login import current_account_with_tenant, login_required diff --git a/api/controllers/console/workspace/agent_providers.py b/api/controllers/console/workspace/agent_providers.py index e2b504751b..3fdcbc4710 100644 --- a/api/controllers/console/workspace/agent_providers.py +++ b/api/controllers/console/workspace/agent_providers.py @@ -1,8 +1,8 @@ from flask_restx import Resource, fields +from graphon.model_runtime.utils.encoders import jsonable_encoder from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, 
setup_required -from dify_graph.model_runtime.utils.encoders import jsonable_encoder from libs.login import current_account_with_tenant, login_required from services.agent_service import AgentService diff --git a/api/controllers/console/workspace/endpoint.py b/api/controllers/console/workspace/endpoint.py index 538c5fb561..b6b9deb1f9 100644 --- a/api/controllers/console/workspace/endpoint.py +++ b/api/controllers/console/workspace/endpoint.py @@ -2,13 +2,13 @@ from typing import Any from flask import request from flask_restx import Resource +from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel, Field from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required from core.plugin.impl.exc import PluginPermissionDeniedError -from dify_graph.model_runtime.utils.encoders import jsonable_encoder from libs.login import current_account_with_tenant, login_required from services.plugin.endpoint_service import EndpointService diff --git a/api/controllers/console/workspace/load_balancing_config.py b/api/controllers/console/workspace/load_balancing_config.py index 0a9e54de99..e4cfca9fa4 100644 --- a/api/controllers/console/workspace/load_balancing_config.py +++ b/api/controllers/console/workspace/load_balancing_config.py @@ -1,12 +1,12 @@ from flask_restx import Resource +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.errors.validate import CredentialsValidateFailedError from pydantic import BaseModel from werkzeug.exceptions import Forbidden from controllers.common.schema import register_schema_models from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, setup_required -from dify_graph.model_runtime.entities.model_entities import ModelType -from 
dify_graph.model_runtime.errors.validate import CredentialsValidateFailedError from libs.login import current_account_with_tenant, login_required from models import TenantAccountRole from services.model_load_balancing_service import ModelLoadBalancingService diff --git a/api/controllers/console/workspace/model_providers.py b/api/controllers/console/workspace/model_providers.py index db3b02ae94..8e0aefc9e3 100644 --- a/api/controllers/console/workspace/model_providers.py +++ b/api/controllers/console/workspace/model_providers.py @@ -3,13 +3,13 @@ from typing import Any, Literal from flask import request, send_file from flask_restx import Resource +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.errors.validate import CredentialsValidateFailedError +from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel, Field, field_validator from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required -from dify_graph.model_runtime.entities.model_entities import ModelType -from dify_graph.model_runtime.errors.validate import CredentialsValidateFailedError -from dify_graph.model_runtime.utils.encoders import jsonable_encoder from libs.helper import uuid_value from libs.login import current_account_with_tenant, login_required from services.billing_service import BillingService diff --git a/api/controllers/console/workspace/models.py b/api/controllers/console/workspace/models.py index d7eceb656c..2ec1a9435a 100644 --- a/api/controllers/console/workspace/models.py +++ b/api/controllers/console/workspace/models.py @@ -3,14 +3,14 @@ from typing import Any, cast from flask import request from flask_restx import Resource +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.errors.validate import CredentialsValidateFailedError +from 
graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel, Field, field_validator from controllers.common.schema import register_enum_models, register_schema_models from controllers.console import console_ns from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required -from dify_graph.model_runtime.entities.model_entities import ModelType -from dify_graph.model_runtime.errors.validate import CredentialsValidateFailedError -from dify_graph.model_runtime.utils.encoders import jsonable_encoder from libs.helper import uuid_value from libs.login import current_account_with_tenant, login_required from services.model_load_balancing_service import ModelLoadBalancingService @@ -282,14 +282,18 @@ class ModelProviderModelCredentialApi(Resource): ) if args.config_from == "predefined-model": - available_credentials = model_provider_service.provider_manager.get_provider_available_credentials( - tenant_id=tenant_id, provider_name=provider + available_credentials = model_provider_service.get_provider_available_credentials( + tenant_id=tenant_id, + provider=provider, ) else: # Normalize model_type to the origin value stored in DB (e.g., "text-generation" for LLM) normalized_model_type = args.model_type.to_origin_model_type() - available_credentials = model_provider_service.provider_manager.get_provider_model_available_credentials( - tenant_id=tenant_id, provider_name=provider, model_type=normalized_model_type, model_name=args.model + available_credentials = model_provider_service.get_provider_model_available_credentials( + tenant_id=tenant_id, + provider=provider, + model_type=normalized_model_type, + model=args.model, ) return jsonable_encoder( diff --git a/api/controllers/console/workspace/plugin.py b/api/controllers/console/workspace/plugin.py index ee537367c7..aa674a63b3 100644 --- a/api/controllers/console/workspace/plugin.py +++ b/api/controllers/console/workspace/plugin.py @@ -4,6 +4,7 @@ 
from typing import Any, Literal from flask import request, send_file from flask_restx import Resource +from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel, Field from werkzeug.datastructures import FileStorage from werkzeug.exceptions import Forbidden @@ -14,7 +15,6 @@ from controllers.console import console_ns from controllers.console.workspace import plugin_permission_required from controllers.console.wraps import account_initialization_required, is_admin_or_owner_required, setup_required from core.plugin.impl.exc import PluginDaemonClientSideError -from dify_graph.model_runtime.utils.encoders import jsonable_encoder from libs.login import current_account_with_tenant, login_required from models.account import TenantPluginAutoUpgradeStrategy, TenantPluginPermission from services.plugin.plugin_auto_upgrade_service import PluginAutoUpgradeService @@ -200,7 +200,7 @@ class PluginDebuggingKeyApi(Resource): "port": dify_config.PLUGIN_REMOTE_INSTALL_PORT, } except PluginDaemonClientSideError as e: - raise ValueError(e) + return {"code": "plugin_error", "message": e.description}, 400 @console_ns.route("/workspaces/current/plugin/list") @@ -215,7 +215,7 @@ class PluginListApi(Resource): try: plugins_with_total = PluginService.list_with_total(tenant_id, args.page, args.page_size) except PluginDaemonClientSideError as e: - raise ValueError(e) + return {"code": "plugin_error", "message": e.description}, 400 return jsonable_encoder({"plugins": plugins_with_total.list, "total": plugins_with_total.total}) @@ -232,7 +232,7 @@ class PluginListLatestVersionsApi(Resource): try: versions = PluginService.list_latest_versions(args.plugin_ids) except PluginDaemonClientSideError as e: - raise ValueError(e) + return {"code": "plugin_error", "message": e.description}, 400 return jsonable_encoder({"versions": versions}) @@ -251,7 +251,7 @@ class PluginListInstallationsFromIdsApi(Resource): try: plugins = 
PluginService.list_installations_from_ids(tenant_id, args.plugin_ids) except PluginDaemonClientSideError as e: - raise ValueError(e) + return {"code": "plugin_error", "message": e.description}, 400 return jsonable_encoder({"plugins": plugins}) @@ -266,7 +266,7 @@ class PluginIconApi(Resource): try: icon_bytes, mimetype = PluginService.get_asset(args.tenant_id, args.filename) except PluginDaemonClientSideError as e: - raise ValueError(e) + return {"code": "plugin_error", "message": e.description}, 400 icon_cache_max_age = dify_config.TOOL_ICON_CACHE_MAX_AGE return send_file(io.BytesIO(icon_bytes), mimetype=mimetype, max_age=icon_cache_max_age) @@ -286,7 +286,7 @@ class PluginAssetApi(Resource): binary = PluginService.extract_asset(tenant_id, args.plugin_unique_identifier, args.file_name) return send_file(io.BytesIO(binary), mimetype="application/octet-stream") except PluginDaemonClientSideError as e: - raise ValueError(e) + return {"code": "plugin_error", "message": e.description}, 400 @console_ns.route("/workspaces/current/plugin/upload/pkg") @@ -303,7 +303,7 @@ class PluginUploadFromPkgApi(Resource): try: response = PluginService.upload_pkg(tenant_id, content) except PluginDaemonClientSideError as e: - raise ValueError(e) + return {"code": "plugin_error", "message": e.description}, 400 return jsonable_encoder(response) @@ -323,7 +323,7 @@ class PluginUploadFromGithubApi(Resource): try: response = PluginService.upload_pkg_from_github(tenant_id, args.repo, args.version, args.package) except PluginDaemonClientSideError as e: - raise ValueError(e) + return {"code": "plugin_error", "message": e.description}, 400 return jsonable_encoder(response) @@ -361,7 +361,7 @@ class PluginInstallFromPkgApi(Resource): try: response = PluginService.install_from_local_pkg(tenant_id, args.plugin_unique_identifiers) except PluginDaemonClientSideError as e: - raise ValueError(e) + return {"code": "plugin_error", "message": e.description}, 400 return jsonable_encoder(response) @@ -387,7 
+387,7 @@ class PluginInstallFromGithubApi(Resource): args.package, ) except PluginDaemonClientSideError as e: - raise ValueError(e) + return {"code": "plugin_error", "message": e.description}, 400 return jsonable_encoder(response) @@ -407,7 +407,7 @@ class PluginInstallFromMarketplaceApi(Resource): try: response = PluginService.install_from_marketplace_pkg(tenant_id, args.plugin_unique_identifiers) except PluginDaemonClientSideError as e: - raise ValueError(e) + return {"code": "plugin_error", "message": e.description}, 400 return jsonable_encoder(response) @@ -433,7 +433,7 @@ class PluginFetchMarketplacePkgApi(Resource): } ) except PluginDaemonClientSideError as e: - raise ValueError(e) + return {"code": "plugin_error", "message": e.description}, 400 @console_ns.route("/workspaces/current/plugin/fetch-manifest") @@ -453,7 +453,7 @@ class PluginFetchManifestApi(Resource): {"manifest": PluginService.fetch_plugin_manifest(tenant_id, args.plugin_unique_identifier).model_dump()} ) except PluginDaemonClientSideError as e: - raise ValueError(e) + return {"code": "plugin_error", "message": e.description}, 400 @console_ns.route("/workspaces/current/plugin/tasks") @@ -471,7 +471,7 @@ class PluginFetchInstallTasksApi(Resource): try: return jsonable_encoder({"tasks": PluginService.fetch_install_tasks(tenant_id, args.page, args.page_size)}) except PluginDaemonClientSideError as e: - raise ValueError(e) + return {"code": "plugin_error", "message": e.description}, 400 @console_ns.route("/workspaces/current/plugin/tasks/") @@ -486,7 +486,7 @@ class PluginFetchInstallTaskApi(Resource): try: return jsonable_encoder({"task": PluginService.fetch_install_task(tenant_id, task_id)}) except PluginDaemonClientSideError as e: - raise ValueError(e) + return {"code": "plugin_error", "message": e.description}, 400 @console_ns.route("/workspaces/current/plugin/tasks//delete") @@ -501,7 +501,7 @@ class PluginDeleteInstallTaskApi(Resource): try: return {"success": 
PluginService.delete_install_task(tenant_id, task_id)} except PluginDaemonClientSideError as e: - raise ValueError(e) + return {"code": "plugin_error", "message": e.description}, 400 @console_ns.route("/workspaces/current/plugin/tasks/delete_all") @@ -516,7 +516,7 @@ class PluginDeleteAllInstallTaskItemsApi(Resource): try: return {"success": PluginService.delete_all_install_task_items(tenant_id)} except PluginDaemonClientSideError as e: - raise ValueError(e) + return {"code": "plugin_error", "message": e.description}, 400 @console_ns.route("/workspaces/current/plugin/tasks//delete/") @@ -531,7 +531,7 @@ class PluginDeleteInstallTaskItemApi(Resource): try: return {"success": PluginService.delete_install_task_item(tenant_id, task_id, identifier)} except PluginDaemonClientSideError as e: - raise ValueError(e) + return {"code": "plugin_error", "message": e.description}, 400 @console_ns.route("/workspaces/current/plugin/upgrade/marketplace") @@ -553,7 +553,7 @@ class PluginUpgradeFromMarketplaceApi(Resource): ) ) except PluginDaemonClientSideError as e: - raise ValueError(e) + return {"code": "plugin_error", "message": e.description}, 400 @console_ns.route("/workspaces/current/plugin/upgrade/github") @@ -580,7 +580,7 @@ class PluginUpgradeFromGithubApi(Resource): ) ) except PluginDaemonClientSideError as e: - raise ValueError(e) + return {"code": "plugin_error", "message": e.description}, 400 @console_ns.route("/workspaces/current/plugin/uninstall") @@ -598,7 +598,7 @@ class PluginUninstallApi(Resource): try: return {"success": PluginService.uninstall(tenant_id, args.plugin_installation_id)} except PluginDaemonClientSideError as e: - raise ValueError(e) + return {"code": "plugin_error", "message": e.description}, 400 @console_ns.route("/workspaces/current/plugin/permission/change") @@ -674,7 +674,7 @@ class PluginFetchDynamicSelectOptionsApi(Resource): provider_type=args.provider_type, ) except PluginDaemonClientSideError as e: - raise ValueError(e) + return {"code": 
"plugin_error", "message": e.description}, 400 return jsonable_encoder({"options": options}) @@ -705,7 +705,7 @@ class PluginFetchDynamicSelectOptionsWithCredentialsApi(Resource): credentials=args.credentials, ) except PluginDaemonClientSideError as e: - raise ValueError(e) + return {"code": "plugin_error", "message": e.description}, 400 return jsonable_encoder({"options": options}) diff --git a/api/controllers/console/workspace/tool_providers.py b/api/controllers/console/workspace/tool_providers.py index b38f05795a..02eb0adc94 100644 --- a/api/controllers/console/workspace/tool_providers.py +++ b/api/controllers/console/workspace/tool_providers.py @@ -5,6 +5,7 @@ from urllib.parse import urlparse from flask import make_response, redirect, request, send_file from flask_restx import Resource +from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel, Field, HttpUrl, field_validator, model_validator from sqlalchemy.orm import Session from werkzeug.exceptions import Forbidden @@ -26,7 +27,6 @@ from core.mcp.mcp_client import MCPClient from core.plugin.entities.plugin_daemon import CredentialType from core.plugin.impl.oauth import OAuthHandler from core.tools.entities.tool_entities import ApiProviderSchemaType, WorkflowToolParameterConfiguration -from dify_graph.model_runtime.utils.encoders import jsonable_encoder from extensions.ext_database import db from libs.helper import alphanumeric, uuid_value from libs.login import current_account_with_tenant, login_required diff --git a/api/controllers/console/workspace/trigger_providers.py b/api/controllers/console/workspace/trigger_providers.py index ad78d2a623..265b6ecd9a 100644 --- a/api/controllers/console/workspace/trigger_providers.py +++ b/api/controllers/console/workspace/trigger_providers.py @@ -3,6 +3,7 @@ from typing import Any from flask import make_response, redirect, request from flask_restx import Resource +from graphon.model_runtime.utils.encoders import jsonable_encoder 
from pydantic import BaseModel, model_validator from sqlalchemy.orm import Session from werkzeug.exceptions import BadRequest, Forbidden @@ -14,7 +15,6 @@ from core.plugin.entities.plugin_daemon import CredentialType from core.plugin.impl.oauth import OAuthHandler from core.trigger.entities.entities import SubscriptionBuilderUpdater from core.trigger.trigger_manager import TriggerManager -from dify_graph.model_runtime.utils.encoders import jsonable_encoder from extensions.ext_database import db from libs.login import current_user, login_required from models.account import Account diff --git a/api/controllers/files/tool_files.py b/api/controllers/files/tool_files.py index 9e3fb3a90b..2f1e2f28bd 100644 --- a/api/controllers/files/tool_files.py +++ b/api/controllers/files/tool_files.py @@ -70,22 +70,25 @@ class ToolFileApi(Resource): except Exception: raise UnsupportedFileTypeError() + mime_type = tool_file.mime_type + filename = tool_file.filename + response = Response( stream, - mimetype=tool_file.mimetype, + mimetype=mime_type, direct_passthrough=True, headers={}, ) if tool_file.size > 0: response.headers["Content-Length"] = str(tool_file.size) - if args.as_attachment: - encoded_filename = quote(tool_file.name) + if args.as_attachment and filename: + encoded_filename = quote(filename) response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}" enforce_download_for_html( response, - mime_type=tool_file.mimetype, - filename=tool_file.name, + mime_type=mime_type, + filename=filename, extension=extension, ) diff --git a/api/controllers/files/upload.py b/api/controllers/files/upload.py index 52690a12e1..ed3278a28b 100644 --- a/api/controllers/files/upload.py +++ b/api/controllers/files/upload.py @@ -7,8 +7,8 @@ from pydantic import BaseModel, Field from werkzeug.exceptions import Forbidden import services +from core.tools.signature import verify_plugin_file_signature from core.tools.tool_file_manager import ToolFileManager -from 
dify_graph.file.helpers import verify_plugin_file_signature from fields.file_fields import FileResponse from ..common.errors import ( diff --git a/api/controllers/inner_api/__init__.py b/api/controllers/inner_api/__init__.py index 74005217ef..b38994f055 100644 --- a/api/controllers/inner_api/__init__.py +++ b/api/controllers/inner_api/__init__.py @@ -16,12 +16,14 @@ api = ExternalApi( inner_api_ns = Namespace("inner_api", description="Internal API operations", path="/") from . import mail as _mail +from .app import dsl as _app_dsl from .plugin import plugin as _plugin from .workspace import workspace as _workspace api.add_namespace(inner_api_ns) __all__ = [ + "_app_dsl", "_mail", "_plugin", "_workspace", diff --git a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/__init__.py b/api/controllers/inner_api/app/__init__.py similarity index 100% rename from api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/__init__.py rename to api/controllers/inner_api/app/__init__.py diff --git a/api/controllers/inner_api/app/dsl.py b/api/controllers/inner_api/app/dsl.py new file mode 100644 index 0000000000..3b673d6e1d --- /dev/null +++ b/api/controllers/inner_api/app/dsl.py @@ -0,0 +1,111 @@ +"""Inner API endpoints for app DSL import/export. + +Called by the enterprise admin-api service. Import requires ``creator_email`` +to attribute the created app; workspace/membership validation is done by the +Go admin-api caller. 
+""" + +from flask import request +from flask_restx import Resource +from pydantic import BaseModel, Field +from sqlalchemy import select +from sqlalchemy.orm import Session + +from controllers.common.schema import register_schema_model +from controllers.console.wraps import setup_required +from controllers.inner_api import inner_api_ns +from controllers.inner_api.wraps import enterprise_inner_api_only +from extensions.ext_database import db +from models import Account, App +from models.account import AccountStatus +from services.app_dsl_service import AppDslService, ImportMode, ImportStatus + + +class InnerAppDSLImportPayload(BaseModel): + yaml_content: str = Field(description="YAML DSL content") + creator_email: str = Field(description="Email of the workspace member who will own the imported app") + name: str | None = Field(default=None, description="Override app name from DSL") + description: str | None = Field(default=None, description="Override app description from DSL") + + +register_schema_model(inner_api_ns, InnerAppDSLImportPayload) + + +@inner_api_ns.route("/enterprise/workspaces//dsl/import") +class EnterpriseAppDSLImport(Resource): + @setup_required + @enterprise_inner_api_only + @inner_api_ns.doc("enterprise_app_dsl_import") + @inner_api_ns.expect(inner_api_ns.models[InnerAppDSLImportPayload.__name__]) + @inner_api_ns.doc( + responses={ + 200: "Import completed", + 202: "Import pending (DSL version mismatch requires confirmation)", + 400: "Import failed (business error)", + 404: "Creator account not found or inactive", + } + ) + def post(self, workspace_id: str): + """Import a DSL into a workspace on behalf of a specified creator.""" + args = InnerAppDSLImportPayload.model_validate(inner_api_ns.payload or {}) + + account = _get_active_account(args.creator_email) + if account is None: + return {"message": f"account '{args.creator_email}' not found or inactive"}, 404 + + account.set_tenant_id(workspace_id) + + with Session(db.engine) as session: + 
dsl_service = AppDslService(session) + result = dsl_service.import_app( + account=account, + import_mode=ImportMode.YAML_CONTENT, + yaml_content=args.yaml_content, + name=args.name, + description=args.description, + ) + session.commit() + + if result.status == ImportStatus.FAILED: + return result.model_dump(mode="json"), 400 + if result.status == ImportStatus.PENDING: + return result.model_dump(mode="json"), 202 + return result.model_dump(mode="json"), 200 + + +@inner_api_ns.route("/enterprise/apps//dsl") +class EnterpriseAppDSLExport(Resource): + @setup_required + @enterprise_inner_api_only + @inner_api_ns.doc( + "enterprise_app_dsl_export", + responses={ + 200: "Export successful", + 404: "App not found", + }, + ) + def get(self, app_id: str): + """Export an app's DSL as YAML.""" + include_secret = request.args.get("include_secret", "false").lower() == "true" + + app_model = db.session.get(App, app_id) + if not app_model: + return {"message": "app not found"}, 404 + + data = AppDslService.export_dsl( + app_model=app_model, + include_secret=include_secret, + ) + + return {"data": data}, 200 + + +def _get_active_account(email: str) -> Account | None: + """Look up an active account by email. + + Workspace membership is already validated by the Go admin-api caller. 
+ """ + account = db.session.scalar(select(Account).where(Account.email == email).limit(1)) + if account is None or account.status != AccountStatus.ACTIVE: + return None + return account diff --git a/api/controllers/inner_api/plugin/plugin.py b/api/controllers/inner_api/plugin/plugin.py index 9b8b3950e6..83c8fa02fe 100644 --- a/api/controllers/inner_api/plugin/plugin.py +++ b/api/controllers/inner_api/plugin/plugin.py @@ -1,4 +1,5 @@ from flask_restx import Resource +from graphon.model_runtime.utils.encoders import jsonable_encoder from controllers.console.wraps import setup_required from controllers.inner_api import inner_api_ns @@ -28,8 +29,7 @@ from core.plugin.entities.request import ( RequestRequestUploadFile, ) from core.tools.entities.tool_entities import ToolProviderType -from dify_graph.file.helpers import get_signed_file_url_for_plugin -from dify_graph.model_runtime.utils.encoders import jsonable_encoder +from core.tools.signature import get_signed_file_url_for_plugin from libs.helper import length_prefixed_response from models import Account, Tenant from models.model import EndUser diff --git a/api/controllers/mcp/mcp.py b/api/controllers/mcp/mcp.py index 9ddaaa315b..3d00f77e79 100644 --- a/api/controllers/mcp/mcp.py +++ b/api/controllers/mcp/mcp.py @@ -2,6 +2,7 @@ from typing import Any, Union from flask import Response from flask_restx import Resource +from graphon.variables.input_entities import VariableEntity from pydantic import BaseModel, Field, ValidationError from sqlalchemy.orm import Session @@ -9,7 +10,6 @@ from controllers.common.schema import register_schema_model from controllers.mcp import mcp_ns from core.mcp import types as mcp_types from core.mcp.server.streamable_http import handle_mcp_request -from dify_graph.variables.input_entities import VariableEntity from extensions.ext_database import db from libs import helper from models.enums import AppMCPServerStatus diff --git a/api/controllers/service_api/app/audio.py 
b/api/controllers/service_api/app/audio.py index 38d292d0b9..6228cfc25b 100644 --- a/api/controllers/service_api/app/audio.py +++ b/api/controllers/service_api/app/audio.py @@ -2,6 +2,7 @@ import logging from flask import request from flask_restx import Resource +from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel, Field from werkzeug.exceptions import InternalServerError @@ -21,7 +22,6 @@ from controllers.service_api.app.error import ( ) from controllers.service_api.wraps import FetchUserArg, WhereisUserArg, validate_app_token from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from dify_graph.model_runtime.errors.invoke import InvokeError from models.model import App, EndUser from services.audio_service import AudioService from services.errors.audio import ( diff --git a/api/controllers/service_api/app/completion.py b/api/controllers/service_api/app/completion.py index 98f09c44a1..3142e5118e 100644 --- a/api/controllers/service_api/app/completion.py +++ b/api/controllers/service_api/app/completion.py @@ -4,6 +4,7 @@ from uuid import UUID from flask import request from flask_restx import Resource +from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel, Field, field_validator from werkzeug.exceptions import BadRequest, InternalServerError, NotFound @@ -28,7 +29,6 @@ from core.errors.error import ( QuotaExceededError, ) from core.helper.trace_id_helper import get_external_trace_id -from dify_graph.model_runtime.errors.invoke import InvokeError from libs import helper from libs.helper import UUIDStrOrEmpty from models.model import App, AppMode, EndUser diff --git a/api/controllers/service_api/app/file_preview.py b/api/controllers/service_api/app/file_preview.py index f853a124ef..5e7847d784 100644 --- a/api/controllers/service_api/app/file_preview.py +++ b/api/controllers/service_api/app/file_preview.py @@ -4,6 +4,7 @@ from urllib.parse 
import quote from flask import Response, request from flask_restx import Resource from pydantic import BaseModel, Field +from sqlalchemy import select from controllers.common.file_response import enforce_download_for_html from controllers.common.schema import register_schema_model @@ -102,27 +103,27 @@ class FilePreviewApi(Resource): raise FileAccessDeniedError("Invalid file or app identifier") # First, find the MessageFile that references this upload file - message_file = db.session.query(MessageFile).where(MessageFile.upload_file_id == file_id).first() + message_file = db.session.scalar(select(MessageFile).where(MessageFile.upload_file_id == file_id).limit(1)) if not message_file: raise FileNotFoundError("File not found in message context") # Get the message and verify it belongs to the requesting app - message = ( - db.session.query(Message).where(Message.id == message_file.message_id, Message.app_id == app_id).first() + message = db.session.scalar( + select(Message).where(Message.id == message_file.message_id, Message.app_id == app_id).limit(1) ) if not message: raise FileAccessDeniedError("File access denied: not owned by requesting app") # Get the actual upload file record - upload_file = db.session.query(UploadFile).where(UploadFile.id == file_id).first() + upload_file = db.session.get(UploadFile, file_id) if not upload_file: raise FileNotFoundError("Upload file record not found") # Additional security: verify tenant isolation - app = db.session.query(App).where(App.id == app_id).first() + app = db.session.get(App, app_id) if app and upload_file.tenant_id != app.tenant_id: raise FileAccessDeniedError("File access denied: tenant mismatch") diff --git a/api/controllers/service_api/app/site.py b/api/controllers/service_api/app/site.py index 8b47a887bb..bc06e8f386 100644 --- a/api/controllers/service_api/app/site.py +++ b/api/controllers/service_api/app/site.py @@ -1,4 +1,5 @@ from flask_restx import Resource +from sqlalchemy import select from 
werkzeug.exceptions import Forbidden from controllers.common.fields import Site as SiteResponse @@ -28,7 +29,7 @@ class AppSiteApi(Resource): Returns the site configuration for the application including theme, icons, and text. """ - site = db.session.query(Site).where(Site.app_id == app_model.id).first() + site = db.session.scalar(select(Site).where(Site.app_id == app_model.id).limit(1)) if not site: raise Forbidden() diff --git a/api/controllers/service_api/app/workflow.py b/api/controllers/service_api/app/workflow.py index 35dd22c801..1759075139 100644 --- a/api/controllers/service_api/app/workflow.py +++ b/api/controllers/service_api/app/workflow.py @@ -4,6 +4,9 @@ from typing import Any, Literal from dateutil.parser import isoparse from flask import request from flask_restx import Namespace, Resource, fields +from graphon.enums import WorkflowExecutionStatus +from graphon.graph_engine.manager import GraphEngineManager +from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel, Field from sqlalchemy.orm import Session, sessionmaker from werkzeug.exceptions import BadRequest, InternalServerError, NotFound @@ -27,9 +30,6 @@ from core.errors.error import ( QuotaExceededError, ) from core.helper.trace_id_helper import get_external_trace_id -from dify_graph.enums import WorkflowExecutionStatus -from dify_graph.graph_engine.manager import GraphEngineManager -from dify_graph.model_runtime.errors.invoke import InvokeError from extensions.ext_database import db from extensions.ext_redis import redis_client from fields.workflow_app_log_fields import build_workflow_app_log_pagination_model diff --git a/api/controllers/service_api/dataset/dataset.py b/api/controllers/service_api/dataset/dataset.py index 25b6436a71..80205b283b 100644 --- a/api/controllers/service_api/dataset/dataset.py +++ b/api/controllers/service_api/dataset/dataset.py @@ -2,6 +2,7 @@ from typing import Any, Literal, cast from flask import request from flask_restx import 
marshal +from graphon.model_runtime.entities.model_entities import ModelType from pydantic import BaseModel, Field, TypeAdapter, field_validator from werkzeug.exceptions import Forbidden, NotFound @@ -14,9 +15,8 @@ from controllers.service_api.wraps import ( DatasetApiResource, cloud_edition_billing_rate_limit_check, ) -from core.provider_manager import ProviderManager +from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager from core.rag.index_processor.constant.index_type import IndexTechniqueType -from dify_graph.model_runtime.entities.model_entities import ModelType from fields.dataset_fields import dataset_detail_fields from fields.tag_fields import DataSetTag from libs.login import current_user @@ -140,10 +140,10 @@ class DatasetListApi(DatasetApiResource): query.page, query.limit, tenant_id, current_user, query.keyword, query.tag_ids, query.include_all ) # check embedding setting - provider_manager = ProviderManager() assert isinstance(current_user, Account) cid = current_user.current_tenant_id assert cid is not None + provider_manager = create_plugin_provider_manager(tenant_id=cid) configurations = provider_manager.get_configurations(tenant_id=cid) embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True) @@ -259,10 +259,10 @@ class DatasetApi(DatasetApiResource): raise Forbidden(str(e)) data = cast(dict[str, Any], marshal(dataset, dataset_detail_fields)) # check embedding setting - provider_manager = ProviderManager() assert isinstance(current_user, Account) cid = current_user.current_tenant_id assert cid is not None + provider_manager = create_plugin_provider_manager(tenant_id=cid) configurations = provider_manager.get_configurations(tenant_id=cid) embedding_models = configurations.get_models(model_type=ModelType.TEXT_EMBEDDING, only_active=True) diff --git a/api/controllers/service_api/dataset/document.py b/api/controllers/service_api/dataset/document.py index d34b4124ae..2c094aa3e6 
100644 --- a/api/controllers/service_api/dataset/document.py +++ b/api/controllers/service_api/dataset/document.py @@ -6,7 +6,7 @@ from uuid import UUID from flask import request, send_file from flask_restx import marshal from pydantic import BaseModel, Field, field_validator, model_validator -from sqlalchemy import desc, select +from sqlalchemy import desc, func, select from werkzeug.exceptions import Forbidden, NotFound import services @@ -155,7 +155,9 @@ class DocumentAddByTextApi(DatasetApiResource): dataset_id = str(dataset_id) tenant_id = str(tenant_id) - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + ) if not dataset: raise ValueError("Dataset does not exist.") @@ -238,7 +240,9 @@ class DocumentUpdateByTextApi(DatasetApiResource): def post(self, tenant_id: str, dataset_id: UUID, document_id: UUID): """Update document by text.""" payload = DocumentTextUpdate.model_validate(service_api_ns.payload or {}) - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == str(dataset_id)).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == str(dataset_id)).limit(1) + ) args = payload.model_dump(exclude_none=True) if not dataset: raise ValueError("Dataset does not exist.") @@ -315,7 +319,9 @@ class DocumentAddByFileApi(DatasetApiResource): @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def post(self, tenant_id, dataset_id): """Create document by upload file.""" - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + ) if not dataset: raise ValueError("Dataset does not exist.") @@ -425,7 +431,9 @@ class 
DocumentUpdateByFileApi(DatasetApiResource): @cloud_edition_billing_rate_limit_check("knowledge", "dataset") def post(self, tenant_id, dataset_id, document_id): """Update document by upload file.""" - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + ) if not dataset: raise ValueError("Dataset does not exist.") @@ -515,7 +523,9 @@ class DocumentListApi(DatasetApiResource): dataset_id = str(dataset_id) tenant_id = str(tenant_id) query_params = DocumentListQuery.model_validate(request.args.to_dict()) - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + ) if not dataset: raise NotFound("Dataset not found.") @@ -609,7 +619,9 @@ class DocumentIndexingStatusApi(DatasetApiResource): batch = str(batch) tenant_id = str(tenant_id) # get dataset - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + ) if not dataset: raise NotFound("Dataset not found.") # get documents @@ -619,20 +631,23 @@ class DocumentIndexingStatusApi(DatasetApiResource): documents_status = [] for document in documents: completed_segments = ( - db.session.query(DocumentSegment) - .where( - DocumentSegment.completed_at.isnot(None), - DocumentSegment.document_id == str(document.id), - DocumentSegment.status != SegmentStatus.RE_SEGMENT, + db.session.scalar( + select(func.count(DocumentSegment.id)).where( + DocumentSegment.completed_at.isnot(None), + DocumentSegment.document_id == str(document.id), + DocumentSegment.status != SegmentStatus.RE_SEGMENT, + ) ) - .count() + or 0 ) 
total_segments = ( - db.session.query(DocumentSegment) - .where( - DocumentSegment.document_id == str(document.id), DocumentSegment.status != SegmentStatus.RE_SEGMENT + db.session.scalar( + select(func.count(DocumentSegment.id)).where( + DocumentSegment.document_id == str(document.id), + DocumentSegment.status != SegmentStatus.RE_SEGMENT, + ) ) - .count() + or 0 ) # Create a dictionary with document attributes and additional fields document_dict = { @@ -822,7 +837,9 @@ class DocumentApi(DatasetApiResource): tenant_id = str(tenant_id) # get dataset info - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + ) if not dataset: raise ValueError("Dataset does not exist.") diff --git a/api/controllers/service_api/dataset/segment.py b/api/controllers/service_api/dataset/segment.py index 595b01a9f2..b4cc9874b6 100644 --- a/api/controllers/service_api/dataset/segment.py +++ b/api/controllers/service_api/dataset/segment.py @@ -2,7 +2,9 @@ from typing import Any from flask import request from flask_restx import marshal +from graphon.model_runtime.entities.model_entities import ModelType from pydantic import BaseModel, Field +from sqlalchemy import select from werkzeug.exceptions import NotFound from configs import dify_config @@ -18,7 +20,6 @@ from controllers.service_api.wraps import ( from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.model_manager import ModelManager from core.rag.index_processor.constant.index_type import IndexTechniqueType -from dify_graph.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db from fields.segment_fields import child_chunk_fields, segment_fields from libs.login import current_account_with_tenant @@ -92,7 +93,9 @@ class SegmentApi(DatasetApiResource): _, current_tenant_id = 
current_account_with_tenant() """Create single segment.""" # check dataset - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + ) if not dataset: raise NotFound("Dataset not found.") # check document @@ -106,7 +109,7 @@ class SegmentApi(DatasetApiResource): # check embedding model setting if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=current_tenant_id) model_manager.get_model_instance( tenant_id=current_tenant_id, provider=dataset.embedding_model_provider, @@ -150,7 +153,9 @@ class SegmentApi(DatasetApiResource): # check dataset page = request.args.get("page", default=1, type=int) limit = request.args.get("limit", default=20, type=int) - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + ) if not dataset: raise NotFound("Dataset not found.") # check document @@ -160,7 +165,7 @@ class SegmentApi(DatasetApiResource): # check embedding model setting if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=current_tenant_id) model_manager.get_model_instance( tenant_id=current_tenant_id, provider=dataset.embedding_model_provider, @@ -220,7 +225,9 @@ class DatasetSegmentApi(DatasetApiResource): def delete(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str): _, current_tenant_id = current_account_with_tenant() # check dataset - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.scalar( + 
select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + ) if not dataset: raise NotFound("Dataset not found.") # check user's model setting @@ -254,7 +261,9 @@ class DatasetSegmentApi(DatasetApiResource): def post(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str): _, current_tenant_id = current_account_with_tenant() # check dataset - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + ) if not dataset: raise NotFound("Dataset not found.") # check user's model setting @@ -266,7 +275,7 @@ class DatasetSegmentApi(DatasetApiResource): if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: # check embedding model setting try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=current_tenant_id) model_manager.get_model_instance( tenant_id=current_tenant_id, provider=dataset.embedding_model_provider, @@ -301,7 +310,9 @@ class DatasetSegmentApi(DatasetApiResource): def get(self, tenant_id: str, dataset_id: str, document_id: str, segment_id: str): _, current_tenant_id = current_account_with_tenant() # check dataset - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + ) if not dataset: raise NotFound("Dataset not found.") # check user's model setting @@ -344,7 +355,9 @@ class ChildChunkApi(DatasetApiResource): _, current_tenant_id = current_account_with_tenant() """Create child chunk.""" # check dataset - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == 
dataset_id).limit(1) + ) if not dataset: raise NotFound("Dataset not found.") @@ -361,7 +374,7 @@ class ChildChunkApi(DatasetApiResource): # check embedding model setting if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=current_tenant_id) model_manager.get_model_instance( tenant_id=current_tenant_id, provider=dataset.embedding_model_provider, @@ -402,7 +415,9 @@ class ChildChunkApi(DatasetApiResource): _, current_tenant_id = current_account_with_tenant() """Get child chunks.""" # check dataset - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + ) if not dataset: raise NotFound("Dataset not found.") @@ -468,7 +483,9 @@ class DatasetChildChunkApi(DatasetApiResource): _, current_tenant_id = current_account_with_tenant() """Delete child chunk.""" # check dataset - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + ) if not dataset: raise NotFound("Dataset not found.") @@ -527,7 +544,9 @@ class DatasetChildChunkApi(DatasetApiResource): _, current_tenant_id = current_account_with_tenant() """Update child chunk.""" # check dataset - dataset = db.session.query(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id == dataset_id).limit(1) + ) if not dataset: raise NotFound("Dataset not found.") diff --git a/api/controllers/service_api/workspace/models.py b/api/controllers/service_api/workspace/models.py index 35aed40a59..c0a6cb0a76 100644 --- a/api/controllers/service_api/workspace/models.py 
+++ b/api/controllers/service_api/workspace/models.py @@ -1,9 +1,9 @@ from flask_login import current_user from flask_restx import Resource +from graphon.model_runtime.utils.encoders import jsonable_encoder from controllers.service_api import service_api_ns from controllers.service_api.wraps import validate_dataset_token -from dify_graph.model_runtime.utils.encoders import jsonable_encoder from services.model_provider_service import ModelProviderService diff --git a/api/controllers/service_api/wraps.py b/api/controllers/service_api/wraps.py index 7aa5b2f092..1d52b8a737 100644 --- a/api/controllers/service_api/wraps.py +++ b/api/controllers/service_api/wraps.py @@ -9,6 +9,7 @@ from flask import current_app, request from flask_login import user_logged_in from flask_restx import Resource from pydantic import BaseModel +from sqlalchemy import select from werkzeug.exceptions import Forbidden, NotFound, Unauthorized from enums.cloud_plan import CloudPlan @@ -62,7 +63,7 @@ def validate_app_token( def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R: api_token = validate_and_get_api_token("app") - app_model = db.session.query(App).where(App.id == api_token.app_id).first() + app_model = db.session.get(App, api_token.app_id) if not app_model: raise Forbidden("The app no longer exists.") @@ -72,7 +73,7 @@ def validate_app_token( if not app_model.enable_api: raise Forbidden("The app's API service has been disabled.") - tenant = db.session.query(Tenant).where(Tenant.id == app_model.tenant_id).first() + tenant = db.session.get(Tenant, app_model.tenant_id) if tenant is None: raise ValueError("Tenant does not exist.") if tenant.status == TenantStatus.ARCHIVE: @@ -106,8 +107,8 @@ def validate_app_token( else: # For service API without end-user context, ensure an Account is logged in # so services relying on current_account_with_tenant() work correctly. 
- tenant_owner_info = ( - db.session.query(Tenant, Account) + tenant_owner_info = db.session.execute( + select(Tenant, Account) .join(TenantAccountJoin, Tenant.id == TenantAccountJoin.tenant_id) .join(Account, TenantAccountJoin.account_id == Account.id) .where( @@ -115,8 +116,7 @@ def validate_app_token( TenantAccountJoin.role == "owner", Tenant.status == TenantStatus.NORMAL, ) - .one_or_none() - ) + ).one_or_none() if tenant_owner_info: tenant_model, account = tenant_owner_info @@ -277,29 +277,28 @@ def validate_dataset_token( # Validate dataset if dataset_id is provided if dataset_id: dataset_id = str(dataset_id) - dataset = ( - db.session.query(Dataset) + dataset = db.session.scalar( + select(Dataset) .where( Dataset.id == dataset_id, Dataset.tenant_id == api_token.tenant_id, ) - .first() + .limit(1) ) if not dataset: raise NotFound("Dataset not found.") if not dataset.enable_api: raise Forbidden("Dataset api access is not enabled.") - tenant_account_join = ( - db.session.query(Tenant, TenantAccountJoin) + tenant_account_join = db.session.execute( + select(Tenant, TenantAccountJoin) .where(Tenant.id == api_token.tenant_id) .where(TenantAccountJoin.tenant_id == Tenant.id) .where(TenantAccountJoin.role.in_(["owner"])) .where(Tenant.status == TenantStatus.NORMAL) - .one_or_none() - ) # TODO: only owner information is required, so only one is returned. + ).one_or_none() # TODO: only owner information is required, so only one is returned. 
if tenant_account_join: tenant, ta = tenant_account_join - account = db.session.query(Account).where(Account.id == ta.account_id).first() + account = db.session.get(Account, ta.account_id) # Login admin if account: account.current_tenant = tenant @@ -360,7 +359,9 @@ class DatasetApiResource(Resource): method_decorators = [validate_dataset_token] def get_dataset(self, dataset_id: str, tenant_id: str) -> Dataset: - dataset = db.session.query(Dataset).where(Dataset.id == dataset_id, Dataset.tenant_id == tenant_id).first() + dataset = db.session.scalar( + select(Dataset).where(Dataset.id == dataset_id, Dataset.tenant_id == tenant_id).limit(1) + ) if not dataset: raise NotFound("Dataset not found.") diff --git a/api/controllers/web/audio.py b/api/controllers/web/audio.py index 2b8f752668..9ba1dc4a3a 100644 --- a/api/controllers/web/audio.py +++ b/api/controllers/web/audio.py @@ -2,6 +2,7 @@ import logging from flask import request from flask_restx import fields, marshal_with +from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel, field_validator from werkzeug.exceptions import InternalServerError @@ -20,7 +21,6 @@ from controllers.web.error import ( ) from controllers.web.wraps import WebApiResource from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from dify_graph.model_runtime.errors.invoke import InvokeError from libs.helper import uuid_value from models.model import App from services.audio_service import AudioService diff --git a/api/controllers/web/completion.py b/api/controllers/web/completion.py index 8634c1f43c..e37f9af5f0 100644 --- a/api/controllers/web/completion.py +++ b/api/controllers/web/completion.py @@ -1,6 +1,7 @@ import logging from typing import Any, Literal +from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel, Field, field_validator from werkzeug.exceptions import InternalServerError, NotFound @@ -25,7 +26,6 @@ from 
core.errors.error import ( ProviderTokenNotInitError, QuotaExceededError, ) -from dify_graph.model_runtime.errors.invoke import InvokeError from libs import helper from libs.helper import uuid_value from models.model import AppMode diff --git a/api/controllers/web/message.py b/api/controllers/web/message.py index aa56292614..c5505dd60d 100644 --- a/api/controllers/web/message.py +++ b/api/controllers/web/message.py @@ -2,6 +2,7 @@ import logging from typing import Literal from flask import request +from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel, Field, TypeAdapter, field_validator from werkzeug.exceptions import InternalServerError, NotFound @@ -20,7 +21,6 @@ from controllers.web.error import ( from controllers.web.wraps import WebApiResource from core.app.entities.app_invoke_entities import InvokeFrom from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from dify_graph.model_runtime.errors.invoke import InvokeError from fields.conversation_fields import ResultResponse from fields.message_fields import SuggestedQuestionsResponse, WebMessageInfiniteScrollPagination, WebMessageListItem from libs import helper diff --git a/api/controllers/web/remote_files.py b/api/controllers/web/remote_files.py index 6a93ef6748..38aeccc642 100644 --- a/api/controllers/web/remote_files.py +++ b/api/controllers/web/remote_files.py @@ -1,6 +1,7 @@ import urllib.parse import httpx +from graphon.file import helpers as file_helpers from pydantic import BaseModel, Field, HttpUrl import services @@ -11,7 +12,6 @@ from controllers.common.errors import ( UnsupportedFileTypeError, ) from core.helper import ssrf_proxy -from dify_graph.file import helpers as file_helpers from extensions.ext_database import db from fields.file_fields import FileWithSignedUrl, RemoteFileInfo from services.file_service import FileService diff --git a/api/controllers/web/workflow.py b/api/controllers/web/workflow.py 
index 508d1a756a..7f5521f9f5 100644 --- a/api/controllers/web/workflow.py +++ b/api/controllers/web/workflow.py @@ -1,6 +1,8 @@ import logging from typing import Any +from graphon.graph_engine.manager import GraphEngineManager +from graphon.model_runtime.errors.invoke import InvokeError from pydantic import BaseModel, Field from werkzeug.exceptions import InternalServerError @@ -22,8 +24,6 @@ from core.errors.error import ( ProviderTokenNotInitError, QuotaExceededError, ) -from dify_graph.graph_engine.manager import GraphEngineManager -from dify_graph.model_runtime.errors.invoke import InvokeError from extensions.ext_redis import redis_client from libs import helper from models.model import App, AppMode, EndUser diff --git a/api/core/agent/base_agent_runner.py b/api/core/agent/base_agent_runner.py index 1bdc8df813..06c746990d 100644 --- a/api/core/agent/base_agent_runner.py +++ b/api/core/agent/base_agent_runner.py @@ -4,7 +4,21 @@ import uuid from decimal import Decimal from typing import Union, cast -from sqlalchemy import select +from graphon.file import file_manager +from graphon.model_runtime.entities import ( + AssistantPromptMessage, + LLMUsage, + PromptMessage, + PromptMessageTool, + SystemPromptMessage, + TextPromptMessageContent, + ToolPromptMessage, + UserPromptMessage, +) +from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes +from graphon.model_runtime.entities.model_entities import ModelFeature +from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from sqlalchemy import func, select from core.agent.entities import AgentEntity, AgentToolEntity from core.app.app_config.features.file_upload.manager import FileUploadConfigManager @@ -15,6 +29,7 @@ from core.app.entities.app_invoke_entities import ( AgentChatAppGenerateEntity, ModelConfigWithCredentialsEntity, ) +from core.app.file_access import DatabaseFileAccessController from 
core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler from core.memory.token_buffer_memory import TokenBufferMemory @@ -26,26 +41,13 @@ from core.tools.entities.tool_entities import ( ) from core.tools.tool_manager import ToolManager from core.tools.utils.dataset_retriever_tool import DatasetRetrieverTool -from dify_graph.file import file_manager -from dify_graph.model_runtime.entities import ( - AssistantPromptMessage, - LLMUsage, - PromptMessage, - PromptMessageTool, - SystemPromptMessage, - TextPromptMessageContent, - ToolPromptMessage, - UserPromptMessage, -) -from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes -from dify_graph.model_runtime.entities.model_entities import ModelFeature -from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from extensions.ext_database import db from factories import file_factory from models.enums import CreatorUserRole from models.model import Conversation, Message, MessageAgentThought, MessageFile logger = logging.getLogger(__name__) +_file_access_controller = DatabaseFileAccessController() class BaseAgentRunner(AppRunner): @@ -102,11 +104,14 @@ class BaseAgentRunner(AppRunner): ) # get how many agent thoughts have been created self.agent_thought_count = ( - db.session.query(MessageAgentThought) - .where( - MessageAgentThought.message_id == self.message.id, + db.session.scalar( + select(func.count()) + .select_from(MessageAgentThought) + .where( + MessageAgentThought.message_id == self.message.id, + ) ) - .count() + or 0 ) db.session.close() @@ -138,6 +143,7 @@ class BaseAgentRunner(AppRunner): tenant_id=self.tenant_id, app_id=self.app_config.app_id, agent_tool=tool, + user_id=self.user_id, invoke_from=self.application_generate_entity.invoke_from, ) assert tool_entity.entity.description @@ 
-524,7 +530,10 @@ class BaseAgentRunner(AppRunner): image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW file_objs = file_factory.build_from_message_files( - message_files=files, tenant_id=self.tenant_id, config=file_extra_config + message_files=files, + tenant_id=self.tenant_id, + config=file_extra_config, + access_controller=_file_access_controller, ) if not file_objs: return UserPromptMessage(content=message.query) diff --git a/api/core/agent/cot_agent_runner.py b/api/core/agent/cot_agent_runner.py index 9271ed10bd..11e2aa062d 100644 --- a/api/core/agent/cot_agent_runner.py +++ b/api/core/agent/cot_agent_runner.py @@ -4,6 +4,15 @@ from abc import ABC, abstractmethod from collections.abc import Generator, Mapping, Sequence from typing import Any +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage +from graphon.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + PromptMessage, + PromptMessageTool, + ToolPromptMessage, + UserPromptMessage, +) + from core.agent.base_agent_runner import BaseAgentRunner from core.agent.entities import AgentScratchpadUnit from core.agent.errors import AgentMaxIterationError @@ -15,14 +24,6 @@ from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransfo from core.tools.__base.tool import Tool from core.tools.entities.tool_entities import ToolInvokeMeta from core.tools.tool_engine import ToolEngine -from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage -from dify_graph.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - PromptMessage, - PromptMessageTool, - ToolPromptMessage, - UserPromptMessage, -) from models.model import Message logger = logging.getLogger(__name__) @@ -122,7 +123,6 @@ class CotAgentRunner(BaseAgentRunner, ABC): tools=[], stop=app_generate_entity.model_conf.stop, stream=True, - user=self.user_id, 
callbacks=[], ) diff --git a/api/core/agent/cot_chat_agent_runner.py b/api/core/agent/cot_chat_agent_runner.py index 89451a0498..a4c438e929 100644 --- a/api/core/agent/cot_chat_agent_runner.py +++ b/api/core/agent/cot_chat_agent_runner.py @@ -1,16 +1,17 @@ import json -from core.agent.cot_agent_runner import CotAgentRunner -from dify_graph.file import file_manager -from dify_graph.model_runtime.entities import ( +from graphon.file import file_manager +from graphon.model_runtime.entities import ( AssistantPromptMessage, PromptMessage, SystemPromptMessage, TextPromptMessageContent, UserPromptMessage, ) -from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes -from dify_graph.model_runtime.utils.encoders import jsonable_encoder +from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes +from graphon.model_runtime.utils.encoders import jsonable_encoder + +from core.agent.cot_agent_runner import CotAgentRunner class CotChatAgentRunner(CotAgentRunner): diff --git a/api/core/agent/cot_completion_agent_runner.py b/api/core/agent/cot_completion_agent_runner.py index 3023b9bc4d..d4c52a8eb1 100644 --- a/api/core/agent/cot_completion_agent_runner.py +++ b/api/core/agent/cot_completion_agent_runner.py @@ -1,13 +1,14 @@ import json -from core.agent.cot_agent_runner import CotAgentRunner -from dify_graph.model_runtime.entities.message_entities import ( +from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, PromptMessage, TextPromptMessageContent, UserPromptMessage, ) -from dify_graph.model_runtime.utils.encoders import jsonable_encoder +from graphon.model_runtime.utils.encoders import jsonable_encoder + +from core.agent.cot_agent_runner import CotAgentRunner class CotCompletionAgentRunner(CotAgentRunner): diff --git a/api/core/agent/fc_agent_runner.py b/api/core/agent/fc_agent_runner.py index 5e13a13b21..fdffde85d0 100644 --- 
a/api/core/agent/fc_agent_runner.py +++ b/api/core/agent/fc_agent_runner.py @@ -4,15 +4,8 @@ from collections.abc import Generator from copy import deepcopy from typing import Any, Union -from core.agent.base_agent_runner import BaseAgentRunner -from core.agent.errors import AgentMaxIterationError -from core.app.apps.base_app_queue_manager import PublishFrom -from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent -from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform -from core.tools.entities.tool_entities import ToolInvokeMeta -from core.tools.tool_engine import ToolEngine -from dify_graph.file import file_manager -from dify_graph.model_runtime.entities import ( +from graphon.file import file_manager +from graphon.model_runtime.entities import ( AssistantPromptMessage, LLMResult, LLMResultChunk, @@ -25,7 +18,15 @@ from dify_graph.model_runtime.entities import ( ToolPromptMessage, UserPromptMessage, ) -from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes +from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes + +from core.agent.base_agent_runner import BaseAgentRunner +from core.agent.errors import AgentMaxIterationError +from core.app.apps.base_app_queue_manager import PublishFrom +from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent +from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform +from core.tools.entities.tool_entities import ToolInvokeMeta +from core.tools.tool_engine import ToolEngine from models.model import Message logger = logging.getLogger(__name__) @@ -96,7 +97,6 @@ class FunctionCallAgentRunner(BaseAgentRunner): tools=prompt_messages_tools, stop=app_generate_entity.model_conf.stop, stream=self.stream_tool_call, - user=self.user_id, callbacks=[], ) 
diff --git a/api/core/agent/output_parser/cot_output_parser.py b/api/core/agent/output_parser/cot_output_parser.py index 82676f1ebd..46c1f1230d 100644 --- a/api/core/agent/output_parser/cot_output_parser.py +++ b/api/core/agent/output_parser/cot_output_parser.py @@ -3,8 +3,9 @@ import re from collections.abc import Generator from typing import Union +from graphon.model_runtime.entities.llm_entities import LLMResultChunk + from core.agent.entities import AgentScratchpadUnit -from dify_graph.model_runtime.entities.llm_entities import LLMResultChunk class CotAgentOutputParser: diff --git a/api/core/app/app_config/easy_ui_based_app/model_config/converter.py b/api/core/app/app_config/easy_ui_based_app/model_config/converter.py index 558b6e69a0..b7dd55632e 100644 --- a/api/core/app/app_config/easy_ui_based_app/model_config/converter.py +++ b/api/core/app/app_config/easy_ui_based_app/model_config/converter.py @@ -1,13 +1,14 @@ from typing import cast +from graphon.model_runtime.entities.llm_entities import LLMMode +from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType +from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel + from core.app.app_config.entities import EasyUIBasedAppConfig from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.entities.model_entities import ModelStatus from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from core.provider_manager import ProviderManager -from dify_graph.model_runtime.entities.llm_entities import LLMMode -from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey, ModelType -from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager class ModelConfigConverter: @@ -21,7 +22,7 @@ class ModelConfigConverter: """ model_config 
= app_config.model - provider_manager = ProviderManager() + provider_manager = create_plugin_provider_manager(tenant_id=app_config.tenant_id) provider_model_bundle = provider_manager.get_provider_model_bundle( tenant_id=app_config.tenant_id, provider=model_config.provider, model_type=ModelType.LLM ) diff --git a/api/core/app/app_config/easy_ui_based_app/model_config/manager.py b/api/core/app/app_config/easy_ui_based_app/model_config/manager.py index 0929f52e33..5cc385c378 100644 --- a/api/core/app/app_config/easy_ui_based_app/model_config/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/model_config/manager.py @@ -1,10 +1,10 @@ from collections.abc import Mapping from typing import Any +from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType + from core.app.app_config.entities import ModelConfigEntity -from core.provider_manager import ProviderManager -from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey, ModelType -from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory +from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly from models.model import AppModelConfigDict from models.provider_ids import ModelProviderID @@ -54,9 +54,12 @@ class ModelConfigManager: if not isinstance(config["model"], dict): raise ValueError("model must be of object type") + # Keep provider discovery and provider-backed model listing on the same + # request-scoped runtime so caller scope and provider caches stay aligned. 
+ assembly = create_plugin_model_assembly(tenant_id=tenant_id) + # model.provider - model_provider_factory = ModelProviderFactory(tenant_id) - provider_entities = model_provider_factory.get_providers() + provider_entities = assembly.model_provider_factory.get_providers() model_provider_names = [provider.provider for provider in provider_entities] if "provider" not in config["model"]: raise ValueError(f"model.provider is required and must be in {str(model_provider_names)}") @@ -71,8 +74,7 @@ class ModelConfigManager: if "name" not in config["model"]: raise ValueError("model.name is required") - provider_manager = ProviderManager() - models = provider_manager.get_configurations(tenant_id).get_models( + models = assembly.provider_manager.get_configurations(tenant_id).get_models( provider=config["model"]["provider"], model_type=ModelType.LLM ) diff --git a/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py b/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py index b7073898d6..76196e7034 100644 --- a/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py +++ b/api/core/app/app_config/easy_ui_based_app/prompt_template/manager.py @@ -1,5 +1,7 @@ from typing import Any +from graphon.model_runtime.entities.message_entities import PromptMessageRole + from core.app.app_config.entities import ( AdvancedChatMessageEntity, AdvancedChatPromptTemplateEntity, @@ -7,7 +9,6 @@ from core.app.app_config.entities import ( PromptTemplateEntity, ) from core.prompt.simple_prompt_transform import ModelMode -from dify_graph.model_runtime.entities.message_entities import PromptMessageRole from models.model import AppMode, AppModelConfigDict diff --git a/api/core/app/app_config/easy_ui_based_app/variables/manager.py b/api/core/app/app_config/easy_ui_based_app/variables/manager.py index 8de1224a89..f0b71c5801 100644 --- a/api/core/app/app_config/easy_ui_based_app/variables/manager.py +++ 
b/api/core/app/app_config/easy_ui_based_app/variables/manager.py @@ -1,9 +1,10 @@ import re from typing import cast +from graphon.variables.input_entities import VariableEntity, VariableEntityType + from core.app.app_config.entities import ExternalDataVariableEntity from core.external_data_tool.factory import ExternalDataToolFactory -from dify_graph.variables.input_entities import VariableEntity, VariableEntityType from models.model import AppModelConfigDict _ALLOWED_VARIABLE_ENTITY_TYPE = frozenset( diff --git a/api/core/app/app_config/entities.py b/api/core/app/app_config/entities.py index 95ea70bc40..536617edba 100644 --- a/api/core/app/app_config/entities.py +++ b/api/core/app/app_config/entities.py @@ -2,13 +2,13 @@ from collections.abc import Sequence from enum import StrEnum, auto from typing import Any, Literal +from graphon.file import FileUploadConfig +from graphon.model_runtime.entities.llm_entities import LLMMode +from graphon.model_runtime.entities.message_entities import PromptMessageRole +from graphon.variables.input_entities import VariableEntity as WorkflowVariableEntity from pydantic import BaseModel, Field from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict -from dify_graph.file import FileUploadConfig -from dify_graph.model_runtime.entities.llm_entities import LLMMode -from dify_graph.model_runtime.entities.message_entities import PromptMessageRole -from dify_graph.variables.input_entities import VariableEntity as WorkflowVariableEntity from models.model import AppMode diff --git a/api/core/app/app_config/features/file_upload/manager.py b/api/core/app/app_config/features/file_upload/manager.py index 0c4266fbeb..e96517c426 100644 --- a/api/core/app/app_config/features/file_upload/manager.py +++ b/api/core/app/app_config/features/file_upload/manager.py @@ -1,8 +1,9 @@ from collections.abc import Mapping from typing import Any +from graphon.file import FileUploadConfig + from constants import 
DEFAULT_FILE_NUMBER_LIMITS -from dify_graph.file import FileUploadConfig class FileUploadConfigManager: diff --git a/api/core/app/app_config/workflow_ui_based_app/variables/manager.py b/api/core/app/app_config/workflow_ui_based_app/variables/manager.py index d2a9a73380..62e0c31d1a 100644 --- a/api/core/app/app_config/workflow_ui_based_app/variables/manager.py +++ b/api/core/app/app_config/workflow_ui_based_app/variables/manager.py @@ -1,7 +1,8 @@ import re +from graphon.variables.input_entities import VariableEntity + from core.app.app_config.entities import RagPipelineVariableEntity -from dify_graph.variables.input_entities import VariableEntity from models.workflow import Workflow diff --git a/api/core/app/apps/advanced_chat/app_generator.py b/api/core/app/apps/advanced_chat/app_generator.py index 5d974335ff..aa2b65766f 100644 --- a/api/core/app/apps/advanced_chat/app_generator.py +++ b/api/core/app/apps/advanced_chat/app_generator.py @@ -5,7 +5,7 @@ import logging import threading import uuid from collections.abc import Generator, Mapping, Sequence -from typing import TYPE_CHECKING, Any, Literal, TypeVar, Union, overload +from typing import TYPE_CHECKING, Any, Literal, Union, overload from flask import Flask, current_app from pydantic import ValidationError @@ -18,12 +18,23 @@ from constants import UUID_NIL if TYPE_CHECKING: from controllers.console.app.workflow import LoopNodeRunPayload +from graphon.graph_engine.layers import GraphEngineLayer +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError +from graphon.runtime import GraphRuntimeState +from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader + from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner from core.app.apps.advanced_chat.generate_response_converter import 
AdvancedChatAppGenerateResponseConverter -from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGenerateTaskPipeline +from core.app.apps.advanced_chat.generate_task_pipeline import ( + AdvancedChatAppGenerateTaskPipeline, + ConversationSnapshot, + MessageSnapshot, + WorkflowSnapshot, +) from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom +from core.app.apps.draft_variable_saver import DraftVariableSaverFactory from core.app.apps.exc import GenerateTaskStoppedError from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager @@ -34,20 +45,11 @@ from core.helper.trace_id_helper import extract_external_trace_id_from_args from core.ops.ops_trace_manager import TraceQueueManager from core.prompt.utils.get_thread_messages_length import get_thread_messages_length from core.repositories import DifyCoreRepositoryFactory -from dify_graph.graph_engine.layers.base import GraphEngineLayer -from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError -from dify_graph.repositories.draft_variable_repository import ( - DraftVariableSaverFactory, -) -from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository -from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository -from dify_graph.runtime import GraphRuntimeState -from dify_graph.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader +from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository from extensions.ext_database import db from factories import file_factory from libs.flask_utils import preserve_flask_contexts from models import Account, App, Conversation, EndUser, Message, Workflow, WorkflowNodeExecutionTriggeredFrom -from models.base import Base from models.enums import WorkflowRunTriggeredFrom from services.conversation_service 
import ConversationService from services.workflow_draft_variable_service import ( @@ -150,85 +152,87 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): # # For implementation reference, see the `_parse_file` function and # `DraftWorkflowNodeRunApi` class which handle this properly. - files = args["files"] if args.get("files") else [] - file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False) - if file_extra_config: - file_objs = file_factory.build_from_mappings( - mappings=files, - tenant_id=app_model.tenant_id, - config=file_extra_config, + with self._bind_file_access_scope(tenant_id=app_model.tenant_id, user=user, invoke_from=invoke_from): + files = args["files"] if args.get("files") else [] + file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False) + if file_extra_config: + file_objs = file_factory.build_from_mappings( + mappings=files, + tenant_id=app_model.tenant_id, + config=file_extra_config, + access_controller=self._file_access_controller, + ) + else: + file_objs = [] + + # convert to app config + app_config = AdvancedChatAppConfigManager.get_app_config(app_model=app_model, workflow=workflow) + + # get tracing instance + trace_manager = TraceQueueManager( + app_id=app_model.id, user_id=user.id if isinstance(user, Account) else user.session_id ) - else: - file_objs = [] - # convert to app config - app_config = AdvancedChatAppConfigManager.get_app_config(app_model=app_model, workflow=workflow) + if invoke_from == InvokeFrom.DEBUGGER: + # always enable retriever resource in debugger mode + app_config.additional_features.show_retrieve_source = True # type: ignore - # get tracing instance - trace_manager = TraceQueueManager( - app_id=app_model.id, user_id=user.id if isinstance(user, Account) else user.session_id - ) + # init application generate entity + application_generate_entity = AdvancedChatAppGenerateEntity( + task_id=str(uuid.uuid4()), + app_config=app_config, + 
file_upload_config=file_extra_config, + conversation_id=conversation.id if conversation else None, + inputs=self._prepare_user_inputs( + user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id + ), + query=query, + files=list(file_objs), + parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL, + user_id=user.id, + stream=streaming, + invoke_from=invoke_from, + extras=extras, + trace_manager=trace_manager, + workflow_run_id=str(workflow_run_id), + ) + contexts.plugin_tool_providers.set({}) + contexts.plugin_tool_providers_lock.set(threading.Lock()) - if invoke_from == InvokeFrom.DEBUGGER: - # always enable retriever resource in debugger mode - app_config.additional_features.show_retrieve_source = True # type: ignore + # Create repositories + # + # Create session factory + session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) + # Create workflow execution(aka workflow run) repository + if invoke_from == InvokeFrom.DEBUGGER: + workflow_triggered_from = WorkflowRunTriggeredFrom.DEBUGGING + else: + workflow_triggered_from = WorkflowRunTriggeredFrom.APP_RUN + workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository( + session_factory=session_factory, + user=user, + app_id=application_generate_entity.app_config.app_id, + triggered_from=workflow_triggered_from, + ) + # Create workflow node execution repository + workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( + session_factory=session_factory, + user=user, + app_id=application_generate_entity.app_config.app_id, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) - # init application generate entity - application_generate_entity = AdvancedChatAppGenerateEntity( - task_id=str(uuid.uuid4()), - app_config=app_config, - file_upload_config=file_extra_config, - conversation_id=conversation.id if conversation else None, - 
inputs=self._prepare_user_inputs( - user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id - ), - query=query, - files=list(file_objs), - parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL, - user_id=user.id, - stream=streaming, - invoke_from=invoke_from, - extras=extras, - trace_manager=trace_manager, - workflow_run_id=str(workflow_run_id), - ) - contexts.plugin_tool_providers.set({}) - contexts.plugin_tool_providers_lock.set(threading.Lock()) - - # Create repositories - # - # Create session factory - session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) - # Create workflow execution(aka workflow run) repository - if invoke_from == InvokeFrom.DEBUGGER: - workflow_triggered_from = WorkflowRunTriggeredFrom.DEBUGGING - else: - workflow_triggered_from = WorkflowRunTriggeredFrom.APP_RUN - workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository( - session_factory=session_factory, - user=user, - app_id=application_generate_entity.app_config.app_id, - triggered_from=workflow_triggered_from, - ) - # Create workflow node execution repository - workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( - session_factory=session_factory, - user=user, - app_id=application_generate_entity.app_config.app_id, - triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, - ) - - return self._generate( - workflow=workflow, - user=user, - invoke_from=invoke_from, - application_generate_entity=application_generate_entity, - workflow_execution_repository=workflow_execution_repository, - workflow_node_execution_repository=workflow_node_execution_repository, - conversation=conversation, - stream=streaming, - pause_state_config=pause_state_config, - ) + return self._generate( + workflow=workflow, + user=user, + invoke_from=invoke_from, + application_generate_entity=application_generate_entity, + 
workflow_execution_repository=workflow_execution_repository, + workflow_node_execution_repository=workflow_node_execution_repository, + conversation=conversation, + stream=streaming, + pause_state_config=pause_state_config, + ) def resume( self, @@ -460,94 +464,91 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): :param conversation: conversation :param stream: is stream """ - is_first_conversation = conversation is None + with self._bind_file_access_scope( + tenant_id=application_generate_entity.app_config.tenant_id, + user=user, + invoke_from=invoke_from, + ): + is_first_conversation = conversation is None - if conversation is not None and message is not None: - pass - else: - conversation, message = self._init_generate_records(application_generate_entity, conversation) + if conversation is not None and message is not None: + pass + else: + conversation, message = self._init_generate_records(application_generate_entity, conversation) - if is_first_conversation: - # update conversation features - conversation.override_model_configs = workflow.features - db.session.commit() - db.session.refresh(conversation) + if is_first_conversation: + # update conversation features + conversation.override_model_configs = workflow.features + db.session.commit() + db.session.refresh(conversation) - # get conversation dialogue count - # NOTE: dialogue_count should not start from 0, - # because during the first conversation, dialogue_count should be 1. - self._dialogue_count = get_thread_messages_length(conversation.id) + 1 + # get conversation dialogue count + # NOTE: dialogue_count should not start from 0, + # because during the first conversation, dialogue_count should be 1. 
+ self._dialogue_count = get_thread_messages_length(conversation.id) + 1 - # init queue manager - queue_manager = MessageBasedAppQueueManager( - task_id=application_generate_entity.task_id, - user_id=application_generate_entity.user_id, - invoke_from=application_generate_entity.invoke_from, - conversation_id=conversation.id, - app_mode=conversation.mode, - message_id=message.id, - ) - - graph_layers: list[GraphEngineLayer] = list(graph_engine_layers) - if pause_state_config is not None: - graph_layers.append( - PauseStatePersistenceLayer( - session_factory=pause_state_config.session_factory, - generate_entity=application_generate_entity, - state_owner_user_id=pause_state_config.state_owner_user_id, - ) + # init queue manager + queue_manager = MessageBasedAppQueueManager( + task_id=application_generate_entity.task_id, + user_id=application_generate_entity.user_id, + invoke_from=application_generate_entity.invoke_from, + conversation_id=conversation.id, + app_mode=conversation.mode, + message_id=message.id, ) - # new thread with request context and contextvars - context = contextvars.copy_context() + graph_layers: list[GraphEngineLayer] = list(graph_engine_layers) + if pause_state_config is not None: + graph_layers.append( + PauseStatePersistenceLayer( + session_factory=pause_state_config.session_factory, + generate_entity=application_generate_entity, + state_owner_user_id=pause_state_config.state_owner_user_id, + ) + ) - worker_thread = threading.Thread( - target=self._generate_worker, - kwargs={ - "flask_app": current_app._get_current_object(), # type: ignore - "application_generate_entity": application_generate_entity, - "queue_manager": queue_manager, - "conversation_id": conversation.id, - "message_id": message.id, - "context": context, - "variable_loader": variable_loader, - "workflow_execution_repository": workflow_execution_repository, - "workflow_node_execution_repository": workflow_node_execution_repository, - "graph_engine_layers": tuple(graph_layers), - 
"graph_runtime_state": graph_runtime_state, - }, - ) + # new thread with request context and contextvars + context = contextvars.copy_context() - worker_thread.start() + worker_thread = threading.Thread( + target=self._generate_worker, + kwargs={ + "flask_app": current_app._get_current_object(), # type: ignore + "application_generate_entity": application_generate_entity, + "queue_manager": queue_manager, + "conversation_id": conversation.id, + "message_id": message.id, + "context": context, + "variable_loader": variable_loader, + "workflow_execution_repository": workflow_execution_repository, + "workflow_node_execution_repository": workflow_node_execution_repository, + "graph_engine_layers": tuple(graph_layers), + "graph_runtime_state": graph_runtime_state, + }, + ) - # release database connection, because the following new thread operations may take a long time - with Session(bind=db.engine, expire_on_commit=False) as session: - workflow = _refresh_model(session, workflow) - message = _refresh_model(session, message) - # workflow_ = session.get(Workflow, workflow.id) - # assert workflow_ is not None - # workflow = workflow_ - # message_ = session.get(Message, message.id) - # assert message_ is not None - # message = message_ - # db.session.refresh(workflow) - # db.session.refresh(message) - # db.session.refresh(user) - db.session.close() + worker_thread.start() - # return response or stream generator - response = self._handle_advanced_chat_response( - application_generate_entity=application_generate_entity, - workflow=workflow, - queue_manager=queue_manager, - conversation=conversation, - message=message, - user=user, - stream=stream, - draft_var_saver_factory=self._get_draft_var_saver_factory(invoke_from, account=user), - ) + # Capture the scalar fields needed by the response pipeline before + # releasing the request-scoped SQLAlchemy session. 
+ workflow_snapshot = WorkflowSnapshot.from_workflow(workflow) + conversation_snapshot = ConversationSnapshot.from_conversation(conversation) + message_snapshot = MessageSnapshot.from_message(message) + db.session.close() - return AdvancedChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) + # return response or stream generator + response = self._handle_advanced_chat_response( + application_generate_entity=application_generate_entity, + workflow=workflow_snapshot, + queue_manager=queue_manager, + conversation=conversation_snapshot, + message=message_snapshot, + user=user, + stream=stream, + draft_var_saver_factory=self._get_draft_var_saver_factory(invoke_from, account=user), + ) + + return AdvancedChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) def _generate_worker( self, @@ -648,10 +649,10 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): self, *, application_generate_entity: AdvancedChatAppGenerateEntity, - workflow: Workflow, + workflow: WorkflowSnapshot, queue_manager: AppQueueManager, - conversation: Conversation, - message: Message, + conversation: ConversationSnapshot, + message: MessageSnapshot, user: Union[Account, EndUser], draft_var_saver_factory: DraftVariableSaverFactory, stream: bool = False, @@ -688,13 +689,3 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator): else: logger.exception("Failed to process generate task pipeline, conversation_id: %s", conversation.id) raise e - - -_T = TypeVar("_T", bound=Base) - - -def _refresh_model(session, model: _T) -> _T: - with Session(bind=db.engine, expire_on_commit=False) as session: - detach_model = session.get(type(model), model.id) - assert detach_model is not None - return detach_model diff --git a/api/core/app/apps/advanced_chat/app_runner.py b/api/core/app/apps/advanced_chat/app_runner.py index 66037696af..a884a1c7f9 100644 --- a/api/core/app/apps/advanced_chat/app_runner.py +++ 
b/api/core/app/apps/advanced_chat/app_runner.py @@ -3,6 +3,12 @@ import time from collections.abc import Mapping, Sequence from typing import Any, cast +from graphon.enums import WorkflowType +from graphon.graph_engine.command_channels import RedisChannel +from graphon.graph_engine.layers import GraphEngineLayer +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.variable_loader import VariableLoader +from graphon.variables.variables import Variable from sqlalchemy import select from sqlalchemy.orm import Session @@ -25,16 +31,15 @@ from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, Workfl from core.db.session_factory import session_factory from core.moderation.base import ModerationError from core.moderation.input_moderation import InputModeration +from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository +from core.workflow.node_factory import get_default_root_node_id +from core.workflow.system_variables import ( + build_bootstrap_variables, + build_system_variables, + system_variables_to_mapping, +) +from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool from core.workflow.workflow_entry import WorkflowEntry -from dify_graph.enums import WorkflowType -from dify_graph.graph_engine.command_channels.redis_channel import RedisChannel -from dify_graph.graph_engine.layers.base import GraphEngineLayer -from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository -from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable -from dify_graph.variable_loader import VariableLoader -from dify_graph.variables.variables import Variable from extensions.ext_database import db from extensions.ext_redis import redis_client from extensions.otel import 
WorkflowAppRunnerHandler, trace_span @@ -90,7 +95,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): app_config = self.application_generate_entity.app_config app_config = cast(AdvancedChatAppConfig, app_config) - system_inputs = SystemVariable( + system_inputs = build_system_variables( query=self.application_generate_entity.query, files=self.application_generate_entity.files, conversation_id=self.conversation.id, @@ -132,6 +137,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): workflow=self._workflow, single_iteration_run=self.application_generate_entity.single_iteration_run, single_loop_run=self.application_generate_entity.single_loop_run, + user_id=self.application_generate_entity.user_id, ) else: inputs = self.application_generate_entity.inputs @@ -150,7 +156,10 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): self.application_generate_entity.inputs = new_inputs self.application_generate_entity.query = new_query - system_inputs.query = new_query + system_inputs = build_system_variables( + system_variables_to_mapping(system_inputs), + query=new_query, + ) # annotation reply if self.handle_annotation_reply( @@ -166,14 +175,17 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): # Create a variable pool. # init variable pool - variable_pool = VariablePool( - system_variables=system_inputs, - user_inputs=new_inputs, - environment_variables=self._workflow.environment_variables, - # Based on the definition of `Variable`, - # `VariableBase` instances can be safely used as `Variable` since they are compatible. 
- conversation_variables=conversation_variables, + variable_pool = VariablePool() + add_variables_to_pool( + variable_pool, + build_bootstrap_variables( + system_variables=system_inputs, + environment_variables=self._workflow.environment_variables, + conversation_variables=conversation_variables, + ), ) + root_node_id = get_default_root_node_id(self._workflow.graph_dict) + add_node_inputs_to_pool(variable_pool, node_id=root_node_id, inputs=new_inputs) # init graph graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.time()) @@ -185,6 +197,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner): user_id=self.application_generate_entity.user_id, user_from=user_from, invoke_from=invoke_from, + root_node_id=root_node_id, ) db.session.close() diff --git a/api/core/app/apps/advanced_chat/generate_task_pipeline.py b/api/core/app/apps/advanced_chat/generate_task_pipeline.py index f7b5030d33..5203de225c 100644 --- a/api/core/app/apps/advanced_chat/generate_task_pipeline.py +++ b/api/core/app/apps/advanced_chat/generate_task_pipeline.py @@ -4,9 +4,17 @@ import re import time from collections.abc import Callable, Generator, Mapping from contextlib import contextmanager +from dataclasses import dataclass +from datetime import datetime from threading import Thread from typing import Any, Union +from graphon.entities.pause_reason import HumanInputRequired +from graphon.enums import WorkflowExecutionStatus +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.model_runtime.utils.encoders import jsonable_encoder +from graphon.nodes import BuiltinNodeTypes +from graphon.runtime import GraphRuntimeState from sqlalchemy import select from sqlalchemy.orm import Session @@ -14,6 +22,7 @@ from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.apps.common.graph_runtime_state_support import GraphRuntimeStateSupport 
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter +from core.app.apps.draft_variable_saver import DraftVariableSaverFactory from core.app.entities.app_invoke_entities import ( AdvancedChatAppGenerateEntity, InvokeFrom, @@ -65,24 +74,66 @@ from core.app.task_pipeline.message_cycle_manager import MessageCycleManager from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk from core.ops.ops_trace_manager import TraceQueueManager from core.repositories.human_input_repository import HumanInputFormRepositoryImpl -from dify_graph.entities.pause_reason import HumanInputRequired -from dify_graph.enums import WorkflowExecutionStatus -from dify_graph.model_runtime.entities.llm_entities import LLMUsage -from dify_graph.model_runtime.utils.encoders import jsonable_encoder -from dify_graph.nodes import BuiltinNodeTypes -from dify_graph.repositories.draft_variable_repository import DraftVariableSaverFactory -from dify_graph.runtime import GraphRuntimeState -from dify_graph.system_variable import SystemVariable +from core.workflow.file_reference import resolve_file_record_id +from core.workflow.system_variables import build_system_variables from extensions.ext_database import db from libs.datetime_utils import naive_utc_now from models import Account, Conversation, EndUser, Message, MessageFile from models.enums import CreatorUserRole, MessageFileBelongsTo, MessageStatus from models.execution_extra_content import HumanInputContent +from models.model import AppMode from models.workflow import Workflow logger = logging.getLogger(__name__) +@dataclass(frozen=True, slots=True) +class WorkflowSnapshot: + id: str + tenant_id: str + features_dict: Mapping[str, Any] + + @classmethod + def from_workflow(cls, workflow: Workflow) -> "WorkflowSnapshot": + return cls( + id=workflow.id, + tenant_id=workflow.tenant_id, + features_dict=dict(workflow.features_dict), + ) + + +@dataclass(frozen=True, slots=True) +class ConversationSnapshot: + id: str + 
mode: AppMode + + @classmethod + def from_conversation(cls, conversation: Conversation) -> "ConversationSnapshot": + return cls( + id=conversation.id, + mode=conversation.mode, + ) + + +@dataclass(frozen=True, slots=True) +class MessageSnapshot: + id: str + query: str + created_at: datetime + status: MessageStatus + answer: str + + @classmethod + def from_message(cls, message: Message) -> "MessageSnapshot": + return cls( + id=message.id, + query=message.query, + created_at=message.created_at, + status=message.status, + answer=message.answer, + ) + + class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): """ AdvancedChatAppGenerateTaskPipeline is a class that generate stream output and state management for Application. @@ -91,10 +142,10 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): def __init__( self, application_generate_entity: AdvancedChatAppGenerateEntity, - workflow: Workflow, + workflow: WorkflowSnapshot, queue_manager: AppQueueManager, - conversation: Conversation, - message: Message, + conversation: ConversationSnapshot, + message: MessageSnapshot, user: Union[Account, EndUser], stream: bool, dialogue_count: int, @@ -117,7 +168,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): else: raise NotImplementedError(f"User type not supported: {type(user)}") - self._workflow_system_variables = SystemVariable( + self._workflow_system_variables = build_system_variables( query=message.query, files=application_generate_entity.files, conversation_id=conversation.id, @@ -155,7 +206,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): self._message_saved_on_pause = False self._seed_graph_runtime_state_from_queue_manager() - def _seed_task_state_from_message(self, message: Message) -> None: + def _seed_task_state_from_message(self, message: MessageSnapshot) -> None: if message.status == MessageStatus.PAUSED and message.answer: self._task_state.answer = message.answer @@ -741,8 +792,9 @@ 
class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): def _load_human_input_form_id(self, *, node_id: str) -> str | None: form_repository = HumanInputFormRepositoryImpl( tenant_id=self._workflow_tenant_id, + workflow_execution_id=self._workflow_run_id, ) - form = form_repository.get_form(self._workflow_run_id, node_id) + form = form_repository.get_form(node_id) if form is None: return None return form.id @@ -933,21 +985,23 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): metadata = self._task_state.metadata.model_dump() message.message_metadata = json.dumps(jsonable_encoder(metadata)) - message_files = [ - MessageFile( - message_id=message.id, - type=file["type"], - transfer_method=file["transfer_method"], - url=file["remote_url"], - belongs_to=MessageFileBelongsTo.ASSISTANT, - upload_file_id=file["related_id"], - created_by_role=CreatorUserRole.ACCOUNT - if message.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} - else CreatorUserRole.END_USER, - created_by=message.from_account_id or message.from_end_user_id or "", + message_files: list[MessageFile] = [] + for file in self._recorded_files: + reference = file.get("reference") or file.get("related_id") + message_files.append( + MessageFile( + message_id=message.id, + type=file["type"], + transfer_method=file["transfer_method"], + url=file["remote_url"], + belongs_to=MessageFileBelongsTo.ASSISTANT, + upload_file_id=resolve_file_record_id(reference if isinstance(reference, str) else None), + created_by_role=CreatorUserRole.ACCOUNT + if message.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER} + else CreatorUserRole.END_USER, + created_by=message.from_account_id or message.from_end_user_id or "", + ) ) - for file in self._recorded_files - ] session.add_all(message_files) def _seed_graph_runtime_state_from_queue_manager(self) -> None: @@ -1003,13 +1057,11 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport): return message def 
_save_output_for_event(self, event: QueueNodeSucceededEvent | QueueNodeExceptionEvent, node_execution_id: str): - with Session(db.engine) as session, session.begin(): - saver = self._draft_var_saver_factory( - session=session, - app_id=self._application_generate_entity.app_config.app_id, - node_id=event.node_id, - node_type=event.node_type, - node_execution_id=node_execution_id, - enclosing_node_id=event.in_loop_id or event.in_iteration_id, - ) - saver.save(event.process_data, event.outputs) + saver = self._draft_var_saver_factory( + app_id=self._application_generate_entity.app_config.app_id, + node_id=event.node_id, + node_type=event.node_type, + node_execution_id=node_execution_id, + enclosing_node_id=event.in_loop_id or event.in_iteration_id, + ) + saver.save(event.process_data, event.outputs) diff --git a/api/core/app/apps/agent_chat/app_generator.py b/api/core/app/apps/agent_chat/app_generator.py index 76a067d7b6..bb258af4c1 100644 --- a/api/core/app/apps/agent_chat/app_generator.py +++ b/api/core/app/apps/agent_chat/app_generator.py @@ -6,6 +6,7 @@ from collections.abc import Generator, Mapping from typing import Any, Literal, Union, overload from flask import Flask, current_app +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from pydantic import ValidationError from configs import dify_config @@ -21,7 +22,6 @@ from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager from core.app.entities.app_invoke_entities import AgentChatAppGenerateEntity, InvokeFrom from core.ops.ops_trace_manager import TraceQueueManager -from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError from extensions.ext_database import db from factories import file_factory from libs.flask_utils import preserve_flask_contexts @@ -129,89 +129,93 @@ class AgentChatAppGenerator(MessageBasedAppGenerator): # # For implementation reference, see the 
`_parse_file` function and # `DraftWorkflowNodeRunApi` class which handle this properly. - files = args.get("files") or [] - file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) - if file_extra_config: - file_objs = file_factory.build_from_mappings( - mappings=files, - tenant_id=app_model.tenant_id, - config=file_extra_config, + with self._bind_file_access_scope(tenant_id=app_model.tenant_id, user=user, invoke_from=invoke_from): + files = args.get("files") or [] + file_extra_config = FileUploadConfigManager.convert( + override_model_config_dict or app_model_config.to_dict() ) - else: - file_objs = [] + if file_extra_config: + file_objs = file_factory.build_from_mappings( + mappings=files, + tenant_id=app_model.tenant_id, + config=file_extra_config, + access_controller=self._file_access_controller, + ) + else: + file_objs = [] - # convert to app config - app_config = AgentChatAppConfigManager.get_app_config( - app_model=app_model, - app_model_config=app_model_config, - conversation=conversation, - override_config_dict=override_model_config_dict, - ) + # convert to app config + app_config = AgentChatAppConfigManager.get_app_config( + app_model=app_model, + app_model_config=app_model_config, + conversation=conversation, + override_config_dict=override_model_config_dict, + ) - # get tracing instance - trace_manager = TraceQueueManager(app_model.id, user.id if isinstance(user, Account) else user.session_id) + # get tracing instance + trace_manager = TraceQueueManager(app_model.id, user.id if isinstance(user, Account) else user.session_id) - # init application generate entity - application_generate_entity = AgentChatAppGenerateEntity( - task_id=str(uuid.uuid4()), - app_config=app_config, - model_conf=ModelConfigConverter.convert(app_config), - file_upload_config=file_extra_config, - conversation_id=conversation.id if conversation else None, - inputs=self._prepare_user_inputs( - user_inputs=inputs, 
variables=app_config.variables, tenant_id=app_model.tenant_id - ), - query=query, - files=list(file_objs), - parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL, - user_id=user.id, - stream=streaming, - invoke_from=invoke_from, - extras=extras, - call_depth=0, - trace_manager=trace_manager, - ) + # init application generate entity + application_generate_entity = AgentChatAppGenerateEntity( + task_id=str(uuid.uuid4()), + app_config=app_config, + model_conf=ModelConfigConverter.convert(app_config), + file_upload_config=file_extra_config, + conversation_id=conversation.id if conversation else None, + inputs=self._prepare_user_inputs( + user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id + ), + query=query, + files=list(file_objs), + parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL, + user_id=user.id, + stream=streaming, + invoke_from=invoke_from, + extras=extras, + call_depth=0, + trace_manager=trace_manager, + ) - # init generate records - (conversation, message) = self._init_generate_records(application_generate_entity, conversation) + # init generate records + (conversation, message) = self._init_generate_records(application_generate_entity, conversation) - # init queue manager - queue_manager = MessageBasedAppQueueManager( - task_id=application_generate_entity.task_id, - user_id=application_generate_entity.user_id, - invoke_from=application_generate_entity.invoke_from, - conversation_id=conversation.id, - app_mode=conversation.mode, - message_id=message.id, - ) + # init queue manager + queue_manager = MessageBasedAppQueueManager( + task_id=application_generate_entity.task_id, + user_id=application_generate_entity.user_id, + invoke_from=application_generate_entity.invoke_from, + conversation_id=conversation.id, + app_mode=conversation.mode, + message_id=message.id, + ) - # new thread with request context and contextvars - context = 
contextvars.copy_context() + # new thread with request context and contextvars + context = contextvars.copy_context() - worker_thread = threading.Thread( - target=self._generate_worker, - kwargs={ - "flask_app": current_app._get_current_object(), # type: ignore - "context": context, - "application_generate_entity": application_generate_entity, - "queue_manager": queue_manager, - "conversation_id": conversation.id, - "message_id": message.id, - }, - ) + worker_thread = threading.Thread( + target=self._generate_worker, + kwargs={ + "flask_app": current_app._get_current_object(), # type: ignore + "context": context, + "application_generate_entity": application_generate_entity, + "queue_manager": queue_manager, + "conversation_id": conversation.id, + "message_id": message.id, + }, + ) - worker_thread.start() + worker_thread.start() - # return response or stream generator - response = self._handle_response( - application_generate_entity=application_generate_entity, - queue_manager=queue_manager, - conversation=conversation, - message=message, - user=user, - stream=streaming, - ) - return AgentChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) + # return response or stream generator + response = self._handle_response( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + conversation=conversation, + message=message, + user=user, + stream=streaming, + ) + return AgentChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) def _generate_worker( self, diff --git a/api/core/app/apps/agent_chat/app_runner.py b/api/core/app/apps/agent_chat/app_runner.py index a81da2e91c..a20d3f3c38 100644 --- a/api/core/app/apps/agent_chat/app_runner.py +++ b/api/core/app/apps/agent_chat/app_runner.py @@ -1,6 +1,9 @@ import logging from typing import cast +from graphon.model_runtime.entities.llm_entities import LLMMode +from graphon.model_runtime.entities.model_entities import ModelFeature, 
ModelPropertyKey +from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from sqlalchemy import select from core.agent.cot_chat_agent_runner import CotChatAgentRunner @@ -15,9 +18,6 @@ from core.app.entities.queue_entities import QueueAnnotationReplyEvent from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.moderation.base import ModerationError -from dify_graph.model_runtime.entities.llm_entities import LLMMode -from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey -from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from extensions.ext_database import db from models.model import App, Conversation, Message diff --git a/api/core/app/apps/base_app_generate_response_converter.py b/api/core/app/apps/base_app_generate_response_converter.py index a92e3dd2ea..66390116d4 100644 --- a/api/core/app/apps/base_app_generate_response_converter.py +++ b/api/core/app/apps/base_app_generate_response_converter.py @@ -3,10 +3,11 @@ from abc import ABC, abstractmethod from collections.abc import Generator, Mapping from typing import Any, Union +from graphon.model_runtime.errors.invoke import InvokeError + from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.task_entities import AppBlockingResponse, AppStreamResponse from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from dify_graph.model_runtime.errors.invoke import InvokeError logger = logging.getLogger(__name__) diff --git a/api/core/app/apps/base_app_generator.py b/api/core/app/apps/base_app_generator.py index 20e6ac98ea..7eccd59d17 100644 --- a/api/core/app/apps/base_app_generator.py +++ b/api/core/app/apps/base_app_generator.py @@ -1,27 +1,89 @@ from collections.abc import Generator, Mapping, Sequence +from contextlib import AbstractContextManager, 
nullcontext from typing import TYPE_CHECKING, Any, Union, final +from graphon.enums import NodeType +from graphon.file import File, FileUploadConfig +from graphon.variables.input_entities import VariableEntityType from sqlalchemy.orm import Session -from core.app.entities.app_invoke_entities import InvokeFrom -from dify_graph.enums import NodeType -from dify_graph.file import File, FileUploadConfig -from dify_graph.repositories.draft_variable_repository import ( +from core.app.apps.draft_variable_saver import ( DraftVariableSaver, DraftVariableSaverFactory, NoopDraftVariableSaver, ) -from dify_graph.variables.input_entities import VariableEntityType +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from core.app.file_access import DatabaseFileAccessController, FileAccessScope, bind_file_access_scope +from extensions.ext_database import db from factories import file_factory from libs.orjson import orjson_dumps from models import Account, EndUser from services.workflow_draft_variable_service import DraftVariableSaver as DraftVariableSaverImpl if TYPE_CHECKING: - from dify_graph.variables.input_entities import VariableEntity + from graphon.variables.input_entities import VariableEntity + + +@final +class _DebuggerDraftVariableSaver: + """Adapter that binds SQLAlchemy session setup outside the saver port.""" + + def __init__( + self, + *, + account: Account, + app_id: str, + node_id: str, + node_type: NodeType, + node_execution_id: str, + enclosing_node_id: str | None = None, + ) -> None: + self._account = account + self._app_id = app_id + self._node_id = node_id + self._node_type = node_type + self._node_execution_id = node_execution_id + self._enclosing_node_id = enclosing_node_id + + def save(self, process_data: Mapping[str, Any] | None, outputs: Mapping[str, Any] | None) -> None: + with Session(db.engine) as session, session.begin(): + DraftVariableSaverImpl( + session=session, + app_id=self._app_id, + node_id=self._node_id, + 
node_type=self._node_type, + node_execution_id=self._node_execution_id, + enclosing_node_id=self._enclosing_node_id, + user=self._account, + ).save(process_data, outputs) class BaseAppGenerator: + _file_access_controller: DatabaseFileAccessController = DatabaseFileAccessController() + + @staticmethod + def _bind_file_access_scope( + *, + tenant_id: str, + user: Account | EndUser, + invoke_from: InvokeFrom, + ) -> AbstractContextManager[None]: + """Bind request-scoped file ownership markers for downstream file lookups.""" + + user_id = getattr(user, "id", None) + if not isinstance(user_id, str) or not user_id: + return nullcontext() + + user_from = UserFrom.ACCOUNT if isinstance(user, Account) else UserFrom.END_USER + return bind_file_access_scope( + FileAccessScope( + tenant_id=tenant_id, + user_id=user_id, + user_from=user_from, + invoke_from=invoke_from, + ) + ) + def _prepare_user_inputs( self, *, @@ -50,6 +112,7 @@ class BaseAppGenerator: allowed_file_upload_methods=entity_dictionary[k].allowed_file_upload_methods or [], ), strict_type_validation=strict_type_validation, + access_controller=self._file_access_controller, ) for k, v in user_inputs.items() if isinstance(v, dict) and entity_dictionary[k].type == VariableEntityType.FILE @@ -64,6 +127,7 @@ class BaseAppGenerator: allowed_file_extensions=entity_dictionary[k].allowed_file_extensions or [], allowed_file_upload_methods=entity_dictionary[k].allowed_file_upload_methods or [], ), + access_controller=self._file_access_controller, ) for k, v in user_inputs.items() if isinstance(v, list) @@ -226,32 +290,30 @@ class BaseAppGenerator: assert isinstance(account, Account) def draft_var_saver_factory( - session: Session, app_id: str, node_id: str, node_type: NodeType, node_execution_id: str, enclosing_node_id: str | None = None, ) -> DraftVariableSaver: - return DraftVariableSaverImpl( - session=session, + return _DebuggerDraftVariableSaver( + account=account, app_id=app_id, node_id=node_id, node_type=node_type, 
node_execution_id=node_execution_id, enclosing_node_id=enclosing_node_id, - user=account, ) else: def draft_var_saver_factory( - session: Session, app_id: str, node_id: str, node_type: NodeType, node_execution_id: str, enclosing_node_id: str | None = None, ) -> DraftVariableSaver: + _ = app_id, node_id, node_type, node_execution_id, enclosing_node_id return NoopDraftVariableSaver() return draft_var_saver_factory diff --git a/api/core/app/apps/base_app_queue_manager.py b/api/core/app/apps/base_app_queue_manager.py index 5addd41815..20bf81aeec 100644 --- a/api/core/app/apps/base_app_queue_manager.py +++ b/api/core/app/apps/base_app_queue_manager.py @@ -7,6 +7,7 @@ from enum import IntEnum, auto from typing import Any from cachetools import TTLCache, cachedmethod +from graphon.runtime import GraphRuntimeState from redis.exceptions import RedisError from sqlalchemy.orm import DeclarativeMeta @@ -20,7 +21,6 @@ from core.app.entities.queue_entities import ( QueueStopEvent, WorkflowQueueMessage, ) -from dify_graph.runtime import GraphRuntimeState from extensions.ext_redis import redis_client logger = logging.getLogger(__name__) @@ -61,27 +61,30 @@ class AppQueueManager(ABC): listen_timeout = dify_config.APP_MAX_EXECUTION_TIME start_time = time.time() last_ping_time: int | float = 0 - while True: - try: - message = self._q.get(timeout=1) - if message is None: - break + try: + while True: + try: + message = self._q.get(timeout=1) + if message is None: + break - yield message - except queue.Empty: - continue - finally: - elapsed_time = time.time() - start_time - if elapsed_time >= listen_timeout or self._is_stopped(): - # publish two messages to make sure the client can receive the stop signal - # and stop listening after the stop signal processed - self.publish( - QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL), PublishFrom.TASK_PIPELINE - ) + yield message + except queue.Empty: + continue + finally: + elapsed_time = time.time() - start_time + if elapsed_time 
>= listen_timeout or self._is_stopped(): + # publish two messages to make sure the client can receive the stop signal + # and stop listening after the stop signal processed + self.publish( + QueueStopEvent(stopped_by=QueueStopEvent.StopBy.USER_MANUAL), PublishFrom.TASK_PIPELINE + ) - if elapsed_time // 10 > last_ping_time: - self.publish(QueuePingEvent(), PublishFrom.TASK_PIPELINE) - last_ping_time = elapsed_time // 10 + if elapsed_time // 10 > last_ping_time: + self.publish(QueuePingEvent(), PublishFrom.TASK_PIPELINE) + last_ping_time = elapsed_time // 10 + finally: + self._graph_runtime_state = None # Release reference once consumers finish or close the generator. def stop_listen(self): """ @@ -90,7 +93,6 @@ class AppQueueManager(ABC): """ self._clear_task_belong_cache() self._q.put(None) - self._graph_runtime_state = None # Release reference to allow GC to reclaim memory def _clear_task_belong_cache(self) -> None: """ diff --git a/api/core/app/apps/base_app_runner.py b/api/core/app/apps/base_app_runner.py index 11fcbb7561..4aebc0cb30 100644 --- a/api/core/app/apps/base_app_runner.py +++ b/api/core/app/apps/base_app_runner.py @@ -5,6 +5,17 @@ from collections.abc import Generator, Mapping, Sequence from mimetypes import guess_extension from typing import TYPE_CHECKING, Any, Union +from graphon.file import FileTransferMethod, FileType +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage +from graphon.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + ImagePromptMessageContent, + PromptMessage, + TextPromptMessageContent, +) +from graphon.model_runtime.entities.model_entities import ModelPropertyKey +from graphon.model_runtime.errors.invoke import InvokeBadRequestError + from core.app.app_config.entities import ExternalDataVariableEntity, PromptTemplateEntity from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.entities.app_invoke_entities 
import ( @@ -29,22 +40,12 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig from core.prompt.simple_prompt_transform import ModelMode, SimplePromptTransform from core.tools.tool_file_manager import ToolFileManager -from dify_graph.file.enums import FileTransferMethod, FileType -from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage -from dify_graph.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - ImagePromptMessageContent, - PromptMessage, - TextPromptMessageContent, -) -from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey -from dify_graph.model_runtime.errors.invoke import InvokeBadRequestError from extensions.ext_database import db from models.enums import CreatorUserRole, MessageFileBelongsTo from models.model import App, AppMode, Message, MessageAnnotation, MessageFile if TYPE_CHECKING: - from dify_graph.file.models import File + from graphon.file import File _logger = logging.getLogger(__name__) diff --git a/api/core/app/apps/chat/app_generator.py b/api/core/app/apps/chat/app_generator.py index 91cf54c774..b675a87382 100644 --- a/api/core/app/apps/chat/app_generator.py +++ b/api/core/app/apps/chat/app_generator.py @@ -1,3 +1,4 @@ +import contextvars import logging import threading import uuid @@ -5,6 +6,7 @@ from collections.abc import Generator, Mapping from typing import Any, Literal, Union, overload from flask import Flask, copy_current_request_context, current_app +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from pydantic import ValidationError from configs import dify_config @@ -20,7 +22,6 @@ from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager from 
core.app.entities.app_invoke_entities import ChatAppGenerateEntity, InvokeFrom from core.ops.ops_trace_manager import TraceQueueManager -from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError from extensions.ext_database import db from factories import file_factory from models import Account @@ -120,89 +121,96 @@ class ChatAppGenerator(MessageBasedAppGenerator): # # For implementation reference, see the `_parse_file` function and # `DraftWorkflowNodeRunApi` class which handle this properly. - files = args["files"] if args.get("files") else [] - file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) - if file_extra_config: - file_objs = file_factory.build_from_mappings( - mappings=files, - tenant_id=app_model.tenant_id, - config=file_extra_config, + with self._bind_file_access_scope(tenant_id=app_model.tenant_id, user=user, invoke_from=invoke_from): + files = args["files"] if args.get("files") else [] + file_extra_config = FileUploadConfigManager.convert( + override_model_config_dict or app_model_config.to_dict() ) - else: - file_objs = [] + if file_extra_config: + file_objs = file_factory.build_from_mappings( + mappings=files, + tenant_id=app_model.tenant_id, + config=file_extra_config, + access_controller=self._file_access_controller, + ) + else: + file_objs = [] - # convert to app config - app_config = ChatAppConfigManager.get_app_config( - app_model=app_model, - app_model_config=app_model_config, - conversation=conversation, - override_config_dict=override_model_config_dict, - ) + # convert to app config + app_config = ChatAppConfigManager.get_app_config( + app_model=app_model, + app_model_config=app_model_config, + conversation=conversation, + override_config_dict=override_model_config_dict, + ) - # get tracing instance - trace_manager = TraceQueueManager( - app_id=app_model.id, user_id=user.id if isinstance(user, Account) else user.session_id - ) + # get tracing instance + 
trace_manager = TraceQueueManager( + app_id=app_model.id, user_id=user.id if isinstance(user, Account) else user.session_id + ) - # init application generate entity - application_generate_entity = ChatAppGenerateEntity( - task_id=str(uuid.uuid4()), - app_config=app_config, - model_conf=ModelConfigConverter.convert(app_config), - file_upload_config=file_extra_config, - conversation_id=conversation.id if conversation else None, - inputs=self._prepare_user_inputs( - user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id - ), - query=query, - files=list(file_objs), - parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL, - user_id=user.id, - invoke_from=invoke_from, - extras=extras, - trace_manager=trace_manager, - stream=streaming, - ) + # init application generate entity + application_generate_entity = ChatAppGenerateEntity( + task_id=str(uuid.uuid4()), + app_config=app_config, + model_conf=ModelConfigConverter.convert(app_config), + file_upload_config=file_extra_config, + conversation_id=conversation.id if conversation else None, + inputs=self._prepare_user_inputs( + user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id + ), + query=query, + files=list(file_objs), + parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL, + user_id=user.id, + invoke_from=invoke_from, + extras=extras, + trace_manager=trace_manager, + stream=streaming, + ) - # init generate records - (conversation, message) = self._init_generate_records(application_generate_entity, conversation) + # init generate records + (conversation, message) = self._init_generate_records(application_generate_entity, conversation) - # init queue manager - queue_manager = MessageBasedAppQueueManager( - task_id=application_generate_entity.task_id, - user_id=application_generate_entity.user_id, - invoke_from=application_generate_entity.invoke_from, - 
conversation_id=conversation.id, - app_mode=conversation.mode, - message_id=message.id, - ) - - # new thread with request context - @copy_current_request_context - def worker_with_context(): - return self._generate_worker( - flask_app=current_app._get_current_object(), # type: ignore - application_generate_entity=application_generate_entity, - queue_manager=queue_manager, + # init queue manager + queue_manager = MessageBasedAppQueueManager( + task_id=application_generate_entity.task_id, + user_id=application_generate_entity.user_id, + invoke_from=application_generate_entity.invoke_from, conversation_id=conversation.id, + app_mode=conversation.mode, message_id=message.id, ) - worker_thread = threading.Thread(target=worker_with_context) + context = contextvars.copy_context() - worker_thread.start() + # new thread with request context + @copy_current_request_context + def worker_with_context(): + return context.run( + self._generate_worker, + flask_app=current_app._get_current_object(), # type: ignore + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + conversation_id=conversation.id, + message_id=message.id, + ) - # return response or stream generator - response = self._handle_response( - application_generate_entity=application_generate_entity, - queue_manager=queue_manager, - conversation=conversation, - message=message, - user=user, - stream=streaming, - ) + worker_thread = threading.Thread(target=worker_with_context) - return ChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) + worker_thread.start() + + # return response or stream generator + response = self._handle_response( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + conversation=conversation, + message=message, + user=user, + stream=streaming, + ) + + return ChatAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) def _generate_worker( self, diff --git 
a/api/core/app/apps/chat/app_runner.py b/api/core/app/apps/chat/app_runner.py index f63b38fc86..050f763e95 100644 --- a/api/core/app/apps/chat/app_runner.py +++ b/api/core/app/apps/chat/app_runner.py @@ -1,6 +1,8 @@ import logging from typing import cast +from graphon.file import File +from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent from sqlalchemy import select from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom @@ -15,8 +17,6 @@ from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.moderation.base import ModerationError from core.rag.retrieval.dataset_retrieval import DatasetRetrieval -from dify_graph.file import File -from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent from extensions.ext_database import db from models.model import App, Conversation, Message @@ -223,7 +223,6 @@ class ChatAppRunner(AppRunner): model_parameters=application_generate_entity.model_conf.parameters, stop=stop, stream=application_generate_entity.stream, - user=application_generate_entity.user_id, ) # handle invoke result diff --git a/api/core/app/apps/common/graph_runtime_state_support.py b/api/core/app/apps/common/graph_runtime_state_support.py index 6a8e436163..ab277857fe 100644 --- a/api/core/app/apps/common/graph_runtime_state_support.py +++ b/api/core/app/apps/common/graph_runtime_state_support.py @@ -4,7 +4,9 @@ from __future__ import annotations from typing import TYPE_CHECKING -from dify_graph.runtime import GraphRuntimeState +from graphon.runtime import GraphRuntimeState + +from core.workflow.system_variables import SystemVariableKey, get_system_text if TYPE_CHECKING: from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline @@ -30,10 +32,10 @@ class GraphRuntimeStateSupport: return self._resolve_graph_runtime_state(graph_runtime_state) def _extract_workflow_run_id(self, 
graph_runtime_state: GraphRuntimeState) -> str: - system_variables = graph_runtime_state.variable_pool.system_variables - if not system_variables or not system_variables.workflow_execution_id: + workflow_run_id = get_system_text(graph_runtime_state.variable_pool, SystemVariableKey.WORKFLOW_EXECUTION_ID) + if not workflow_run_id: raise ValueError("workflow_execution_id missing from runtime state") - return str(system_variables.workflow_execution_id) + return workflow_run_id def _resolve_graph_runtime_state( self, diff --git a/api/core/app/apps/common/workflow_response_converter.py b/api/core/app/apps/common/workflow_response_converter.py index 621b0d8cf3..a515531616 100644 --- a/api/core/app/apps/common/workflow_response_converter.py +++ b/api/core/app/apps/common/workflow_response_converter.py @@ -1,3 +1,4 @@ +import json import logging import time from collections.abc import Mapping, Sequence @@ -5,6 +6,19 @@ from dataclasses import dataclass from datetime import datetime from typing import Any, NewType, TypedDict, Union +from graphon.entities import WorkflowStartReason +from graphon.entities.pause_reason import HumanInputRequired +from graphon.enums import ( + BuiltinNodeTypes, + WorkflowExecutionStatus, + WorkflowNodeExecutionMetadataKey, + WorkflowNodeExecutionStatus, +) +from graphon.file import FILE_MODEL_IDENTITY, File +from graphon.runtime import GraphRuntimeState +from graphon.variables.segments import ArrayFileSegment, FileSegment, Segment +from graphon.variables.variables import Variable +from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from sqlalchemy import select from sqlalchemy.orm import Session @@ -50,21 +64,9 @@ from core.tools.entities.tool_entities import ToolProviderType from core.tools.tool_manager import ToolManager from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE from core.trigger.trigger_manager import TriggerManager +from core.workflow.human_input_forms import load_form_tokens_by_form_id +from 
core.workflow.system_variables import SystemVariableKey, system_variables_to_mapping from core.workflow.workflow_entry import WorkflowEntry -from dify_graph.entities.pause_reason import HumanInputRequired -from dify_graph.entities.workflow_start_reason import WorkflowStartReason -from dify_graph.enums import ( - BuiltinNodeTypes, - SystemVariableKey, - WorkflowExecutionStatus, - WorkflowNodeExecutionMetadataKey, - WorkflowNodeExecutionStatus, -) -from dify_graph.file import FILE_MODEL_IDENTITY, File -from dify_graph.runtime import GraphRuntimeState -from dify_graph.system_variable import SystemVariable -from dify_graph.variables.segments import ArrayFileSegment, FileSegment, Segment -from dify_graph.workflow_type_encoder import WorkflowRuntimeTypeConverter from extensions.ext_database import db from libs.datetime_utils import naive_utc_now from models import Account, EndUser @@ -111,11 +113,11 @@ class WorkflowResponseConverter: *, application_generate_entity: Union[AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity], user: Union[Account, EndUser], - system_variables: SystemVariable, + system_variables: Sequence[Variable], ): self._application_generate_entity = application_generate_entity self._user = user - self._system_variables = system_variables + self._system_variables = system_variables_to_mapping(system_variables) self._workflow_inputs = self._prepare_workflow_inputs() # Disable truncation for SERVICE_API calls to keep backward compatibility. 
@@ -133,7 +135,7 @@ class WorkflowResponseConverter: # ------------------------------------------------------------------ def _prepare_workflow_inputs(self) -> Mapping[str, Any]: inputs = dict(self._application_generate_entity.inputs) - for field_name, value in self._system_variables.to_dict().items(): + for field_name, value in self._system_variables.items(): # TODO(@future-refactor): store system variables separately from user inputs so we don't # need to flatten `sys.*` entries into the input payload just for rerun/export tooling. if field_name == SystemVariableKey.CONVERSATION_ID: @@ -318,13 +320,23 @@ class WorkflowResponseConverter: pause_reasons = [reason.model_dump(mode="json") for reason in event.reasons] human_input_form_ids = [reason.form_id for reason in event.reasons if isinstance(reason, HumanInputRequired)] expiration_times_by_form_id: dict[str, datetime] = {} + display_in_ui_by_form_id: dict[str, bool] = {} + form_token_by_form_id: dict[str, str] = {} if human_input_form_ids: - stmt = select(HumanInputForm.id, HumanInputForm.expiration_time).where( - HumanInputForm.id.in_(human_input_form_ids) - ) + stmt = select( + HumanInputForm.id, + HumanInputForm.expiration_time, + HumanInputForm.form_definition, + ).where(HumanInputForm.id.in_(human_input_form_ids)) with Session(bind=db.engine) as session: - for form_id, expiration_time in session.execute(stmt): + for form_id, expiration_time, form_definition in session.execute(stmt): expiration_times_by_form_id[str(form_id)] = expiration_time + try: + definition_payload = json.loads(form_definition) if form_definition else {} + except (TypeError, json.JSONDecodeError): + definition_payload = {} + display_in_ui_by_form_id[str(form_id)] = bool(definition_payload.get("display_in_ui")) + form_token_by_form_id = load_form_tokens_by_form_id(human_input_form_ids, session=session) responses: list[StreamResponse] = [] @@ -344,8 +356,8 @@ class WorkflowResponseConverter: form_content=reason.form_content, 
inputs=reason.inputs, actions=reason.actions, - display_in_ui=reason.display_in_ui, - form_token=reason.form_token, + display_in_ui=display_in_ui_by_form_id.get(reason.form_id, False), + form_token=form_token_by_form_id.get(reason.form_id), resolved_default_values=reason.resolved_default_values, expiration_time=int(expiration_time.timestamp()), ), diff --git a/api/core/app/apps/completion/app_generator.py b/api/core/app/apps/completion/app_generator.py index 002b914ef1..a62c5b80b5 100644 --- a/api/core/app/apps/completion/app_generator.py +++ b/api/core/app/apps/completion/app_generator.py @@ -1,3 +1,4 @@ +import contextvars import logging import threading import uuid @@ -5,6 +6,7 @@ from collections.abc import Generator, Mapping from typing import Any, Literal, Union, overload from flask import Flask, copy_current_request_context, current_app +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from pydantic import ValidationError from sqlalchemy import select @@ -20,7 +22,6 @@ from core.app.apps.message_based_app_generator import MessageBasedAppGenerator from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager from core.app.entities.app_invoke_entities import CompletionAppGenerateEntity, InvokeFrom from core.ops.ops_trace_manager import TraceQueueManager -from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError from extensions.ext_database import db from factories import file_factory from models import Account, App, EndUser, Message @@ -108,83 +109,90 @@ class CompletionAppGenerator(MessageBasedAppGenerator): # # For implementation reference, see the `_parse_file` function and # `DraftWorkflowNodeRunApi` class which handle this properly. 
- files = args["files"] if args.get("files") else [] - file_extra_config = FileUploadConfigManager.convert(override_model_config_dict or app_model_config.to_dict()) - if file_extra_config: - file_objs = file_factory.build_from_mappings( - mappings=files, - tenant_id=app_model.tenant_id, - config=file_extra_config, + with self._bind_file_access_scope(tenant_id=app_model.tenant_id, user=user, invoke_from=invoke_from): + files = args["files"] if args.get("files") else [] + file_extra_config = FileUploadConfigManager.convert( + override_model_config_dict or app_model_config.to_dict() ) - else: - file_objs = [] + if file_extra_config: + file_objs = file_factory.build_from_mappings( + mappings=files, + tenant_id=app_model.tenant_id, + config=file_extra_config, + access_controller=self._file_access_controller, + ) + else: + file_objs = [] - # convert to app config - app_config = CompletionAppConfigManager.get_app_config( - app_model=app_model, app_model_config=app_model_config, override_config_dict=override_model_config_dict - ) + # convert to app config + app_config = CompletionAppConfigManager.get_app_config( + app_model=app_model, app_model_config=app_model_config, override_config_dict=override_model_config_dict + ) - # get tracing instance - trace_manager = TraceQueueManager( - app_id=app_model.id, user_id=user.id if isinstance(user, Account) else user.session_id - ) + # get tracing instance + trace_manager = TraceQueueManager( + app_id=app_model.id, user_id=user.id if isinstance(user, Account) else user.session_id + ) - # init application generate entity - application_generate_entity = CompletionAppGenerateEntity( - task_id=str(uuid.uuid4()), - app_config=app_config, - model_conf=ModelConfigConverter.convert(app_config), - file_upload_config=file_extra_config, - inputs=self._prepare_user_inputs( - user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id - ), - query=query, - files=list(file_objs), - user_id=user.id, - stream=streaming, - 
invoke_from=invoke_from, - extras={}, - trace_manager=trace_manager, - ) + # init application generate entity + application_generate_entity = CompletionAppGenerateEntity( + task_id=str(uuid.uuid4()), + app_config=app_config, + model_conf=ModelConfigConverter.convert(app_config), + file_upload_config=file_extra_config, + inputs=self._prepare_user_inputs( + user_inputs=inputs, variables=app_config.variables, tenant_id=app_model.tenant_id + ), + query=query, + files=list(file_objs), + user_id=user.id, + stream=streaming, + invoke_from=invoke_from, + extras={}, + trace_manager=trace_manager, + ) - # init generate records - (conversation, message) = self._init_generate_records(application_generate_entity) + # init generate records + (conversation, message) = self._init_generate_records(application_generate_entity) - # init queue manager - queue_manager = MessageBasedAppQueueManager( - task_id=application_generate_entity.task_id, - user_id=application_generate_entity.user_id, - invoke_from=application_generate_entity.invoke_from, - conversation_id=conversation.id, - app_mode=conversation.mode, - message_id=message.id, - ) - - # new thread with request context - @copy_current_request_context - def worker_with_context(): - return self._generate_worker( - flask_app=current_app._get_current_object(), # type: ignore - application_generate_entity=application_generate_entity, - queue_manager=queue_manager, + # init queue manager + queue_manager = MessageBasedAppQueueManager( + task_id=application_generate_entity.task_id, + user_id=application_generate_entity.user_id, + invoke_from=application_generate_entity.invoke_from, + conversation_id=conversation.id, + app_mode=conversation.mode, message_id=message.id, ) - worker_thread = threading.Thread(target=worker_with_context) + context = contextvars.copy_context() - worker_thread.start() + # new thread with request context + @copy_current_request_context + def worker_with_context(): + return context.run( + self._generate_worker, + 
flask_app=current_app._get_current_object(), # type: ignore + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + message_id=message.id, + ) - # return response or stream generator - response = self._handle_response( - application_generate_entity=application_generate_entity, - queue_manager=queue_manager, - conversation=conversation, - message=message, - user=user, - stream=streaming, - ) + worker_thread = threading.Thread(target=worker_with_context) - return CompletionAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) + worker_thread.start() + + # return response or stream generator + response = self._handle_response( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + conversation=conversation, + message=message, + user=user, + stream=streaming, + ) + + return CompletionAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) def _generate_worker( self, @@ -280,71 +288,76 @@ class CompletionAppGenerator(MessageBasedAppGenerator): model_dict["completion_params"] = completion_params override_model_config_dict["model"] = model_dict - # parse files - file_extra_config = FileUploadConfigManager.convert(override_model_config_dict) - if file_extra_config: - file_objs = file_factory.build_from_mappings( - mappings=message.message_files, - tenant_id=app_model.tenant_id, - config=file_extra_config, + with self._bind_file_access_scope(tenant_id=app_model.tenant_id, user=user, invoke_from=invoke_from): + # parse files + file_extra_config = FileUploadConfigManager.convert(override_model_config_dict) + if file_extra_config: + file_objs = file_factory.build_from_mappings( + mappings=message.message_files, + tenant_id=app_model.tenant_id, + config=file_extra_config, + access_controller=self._file_access_controller, + ) + else: + file_objs = [] + + # convert to app config + app_config = CompletionAppConfigManager.get_app_config( + 
app_model=app_model, app_model_config=app_model_config, override_config_dict=override_model_config_dict ) - else: - file_objs = [] - # convert to app config - app_config = CompletionAppConfigManager.get_app_config( - app_model=app_model, app_model_config=app_model_config, override_config_dict=override_model_config_dict - ) + # init application generate entity + application_generate_entity = CompletionAppGenerateEntity( + task_id=str(uuid.uuid4()), + app_config=app_config, + model_conf=ModelConfigConverter.convert(app_config), + inputs=message.inputs, + query=message.query, + files=list(file_objs), + user_id=user.id, + stream=stream, + invoke_from=invoke_from, + extras={}, + ) - # init application generate entity - application_generate_entity = CompletionAppGenerateEntity( - task_id=str(uuid.uuid4()), - app_config=app_config, - model_conf=ModelConfigConverter.convert(app_config), - inputs=message.inputs, - query=message.query, - files=list(file_objs), - user_id=user.id, - stream=stream, - invoke_from=invoke_from, - extras={}, - ) + # init generate records + (conversation, message) = self._init_generate_records(application_generate_entity) - # init generate records - (conversation, message) = self._init_generate_records(application_generate_entity) - - # init queue manager - queue_manager = MessageBasedAppQueueManager( - task_id=application_generate_entity.task_id, - user_id=application_generate_entity.user_id, - invoke_from=application_generate_entity.invoke_from, - conversation_id=conversation.id, - app_mode=conversation.mode, - message_id=message.id, - ) - - # new thread with request context - @copy_current_request_context - def worker_with_context(): - return self._generate_worker( - flask_app=current_app._get_current_object(), # type: ignore - application_generate_entity=application_generate_entity, - queue_manager=queue_manager, + # init queue manager + queue_manager = MessageBasedAppQueueManager( + task_id=application_generate_entity.task_id, + 
user_id=application_generate_entity.user_id, + invoke_from=application_generate_entity.invoke_from, + conversation_id=conversation.id, + app_mode=conversation.mode, message_id=message.id, ) - worker_thread = threading.Thread(target=worker_with_context) + context = contextvars.copy_context() - worker_thread.start() + # new thread with request context + @copy_current_request_context + def worker_with_context(): + return context.run( + self._generate_worker, + flask_app=current_app._get_current_object(), # type: ignore + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + message_id=message.id, + ) - # return response or stream generator - response = self._handle_response( - application_generate_entity=application_generate_entity, - queue_manager=queue_manager, - conversation=conversation, - message=message, - user=user, - stream=stream, - ) + worker_thread = threading.Thread(target=worker_with_context) - return CompletionAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) + worker_thread.start() + + # return response or stream generator + response = self._handle_response( + application_generate_entity=application_generate_entity, + queue_manager=queue_manager, + conversation=conversation, + message=message, + user=user, + stream=stream, + ) + + return CompletionAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) diff --git a/api/core/app/apps/completion/app_runner.py b/api/core/app/apps/completion/app_runner.py index 56a4519879..b216f7cf7b 100644 --- a/api/core/app/apps/completion/app_runner.py +++ b/api/core/app/apps/completion/app_runner.py @@ -1,6 +1,8 @@ import logging from typing import cast +from graphon.file import File +from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent from sqlalchemy import select from core.app.apps.base_app_queue_manager import AppQueueManager @@ -13,8 +15,6 @@ from 
core.callback_handler.index_tool_callback_handler import DatasetIndexToolCa from core.model_manager import ModelInstance from core.moderation.base import ModerationError from core.rag.retrieval.dataset_retrieval import DatasetRetrieval -from dify_graph.file import File -from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent from extensions.ext_database import db from models.model import App, Message @@ -181,7 +181,6 @@ class CompletionAppRunner(AppRunner): model_parameters=application_generate_entity.model_conf.parameters, stop=stop, stream=application_generate_entity.stream, - user=application_generate_entity.user_id, ) # handle invoke result diff --git a/api/dify_graph/repositories/draft_variable_repository.py b/api/core/app/apps/draft_variable_saver.py similarity index 65% rename from api/dify_graph/repositories/draft_variable_repository.py rename to api/core/app/apps/draft_variable_saver.py index b2ebfacffd..24018012c5 100644 --- a/api/dify_graph/repositories/draft_variable_repository.py +++ b/api/core/app/apps/draft_variable_saver.py @@ -4,31 +4,30 @@ import abc from collections.abc import Mapping from typing import Any, Protocol -from sqlalchemy.orm import Session - -from dify_graph.enums import NodeType +from graphon.enums import NodeType class DraftVariableSaver(Protocol): @abc.abstractmethod - def save(self, process_data: Mapping[str, Any] | None, outputs: Mapping[str, Any] | None): - pass + def save(self, process_data: Mapping[str, Any] | None, outputs: Mapping[str, Any] | None) -> None: + """Persist node draft variables for a completed execution.""" + raise NotImplementedError class DraftVariableSaverFactory(Protocol): @abc.abstractmethod def __call__( self, - session: Session, app_id: str, node_id: str, node_type: NodeType, node_execution_id: str, enclosing_node_id: str | None = None, ) -> DraftVariableSaver: - pass + """Build a saver bound to a concrete node execution.""" + raise NotImplementedError class 
NoopDraftVariableSaver(DraftVariableSaver): - def save(self, process_data: Mapping[str, Any] | None, outputs: Mapping[str, Any] | None): - pass + def save(self, process_data: Mapping[str, Any] | None, outputs: Mapping[str, Any] | None) -> None: + return None diff --git a/api/core/app/apps/message_based_app_generator.py b/api/core/app/apps/message_based_app_generator.py index 44d10d79b8..fe61224ada 100644 --- a/api/core/app/apps/message_based_app_generator.py +++ b/api/core/app/apps/message_based_app_generator.py @@ -28,6 +28,7 @@ from core.app.entities.task_entities import ( ) from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline from core.prompt.utils.prompt_template_parser import PromptTemplateParser +from core.workflow.file_reference import resolve_file_record_id from extensions.ext_database import db from extensions.ext_redis import get_pubsub_broadcast_channel from libs.broadcast_channel.channel import Topic @@ -227,7 +228,7 @@ class MessageBasedAppGenerator(BaseAppGenerator): transfer_method=file.transfer_method, belongs_to=MessageFileBelongsTo.USER, url=file.remote_url, - upload_file_id=file.related_id, + upload_file_id=resolve_file_record_id(file.reference), created_by_role=(CreatorUserRole.ACCOUNT if account_id else CreatorUserRole.END_USER), created_by=account_id or end_user_id or "", ) diff --git a/api/core/app/apps/pipeline/pipeline_generator.py b/api/core/app/apps/pipeline/pipeline_generator.py index 19d67eb108..fa242003a2 100644 --- a/api/core/app/apps/pipeline/pipeline_generator.py +++ b/api/core/app/apps/pipeline/pipeline_generator.py @@ -10,6 +10,8 @@ from collections.abc import Generator, Mapping from typing import Any, Literal, Union, cast, overload from flask import Flask, current_app +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError +from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader from pydantic import ValidationError from sqlalchemy import select 
from sqlalchemy.orm import Session, sessionmaker @@ -18,6 +20,7 @@ import contexts from configs import dify_config from core.app.apps.base_app_generator import BaseAppGenerator from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom +from core.app.apps.draft_variable_saver import DraftVariableSaverFactory from core.app.apps.exc import GenerateTaskStoppedError from core.app.apps.pipeline.pipeline_config_manager import PipelineConfigManager from core.app.apps.pipeline.pipeline_queue_manager import PipelineQueueManager @@ -34,12 +37,11 @@ from core.datasource.entities.datasource_entities import ( from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin from core.entities.knowledge_entities import PipelineDataset, PipelineDocument from core.rag.index_processor.constant.built_in_field import BuiltInField -from core.repositories.factory import DifyCoreRepositoryFactory -from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError -from dify_graph.repositories.draft_variable_repository import DraftVariableSaverFactory -from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository -from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository -from dify_graph.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader +from core.repositories.factory import ( + DifyCoreRepositoryFactory, + WorkflowExecutionRepository, + WorkflowNodeExecutionRepository, +) from extensions.ext_database import db from libs.flask_utils import preserve_flask_contexts from models import Account, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom diff --git a/api/core/app/apps/pipeline/pipeline_runner.py b/api/core/app/apps/pipeline/pipeline_runner.py index e767766bdb..4c188dac68 100644 --- a/api/core/app/apps/pipeline/pipeline_runner.py +++ b/api/core/app/apps/pipeline/pipeline_runner.py @@ -2,6 +2,14 @@ import logging import time from typing import 
cast +from graphon.entities import GraphInitParams +from graphon.enums import WorkflowType +from graphon.graph import Graph +from graphon.graph_events import GraphEngineEvent, GraphRunFailedEvent +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.variable_loader import VariableLoader +from graphon.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput + from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.pipeline.pipeline_config_manager import PipelineConfig from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner @@ -12,18 +20,11 @@ from core.app.entities.app_invoke_entities import ( build_dify_run_context, ) from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer +from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id +from core.workflow.system_variables import build_bootstrap_variables, build_system_variables +from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool from core.workflow.workflow_entry import WorkflowEntry -from dify_graph.entities.graph_init_params import GraphInitParams -from dify_graph.enums import WorkflowType -from dify_graph.graph import Graph -from dify_graph.graph_events import GraphEngineEvent, GraphRunFailedEvent -from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository -from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable -from dify_graph.variable_loader import VariableLoader -from dify_graph.variables.variables import RAGPipelineVariable, RAGPipelineVariableInput from extensions.ext_database import db from models.dataset import Document, 
Pipeline from models.model import EndUser @@ -106,13 +107,14 @@ class PipelineRunner(WorkflowBasedAppRunner): workflow=workflow, single_iteration_run=self.application_generate_entity.single_iteration_run, single_loop_run=self.application_generate_entity.single_loop_run, + user_id=self.application_generate_entity.user_id, ) else: inputs = self.application_generate_entity.inputs files = self.application_generate_entity.files # Create a variable pool. - system_inputs = SystemVariable( + system_inputs = build_system_variables( files=files, user_id=user_id, app_id=app_config.app_id, @@ -142,19 +144,25 @@ class PipelineRunner(WorkflowBasedAppRunner): ) ) - variable_pool = VariablePool( - system_variables=system_inputs, - user_inputs=inputs, - environment_variables=workflow.environment_variables, - conversation_variables=[], - rag_pipeline_variables=rag_pipeline_variables, + variable_pool = VariablePool() + add_variables_to_pool( + variable_pool, + build_bootstrap_variables( + system_variables=system_inputs, + environment_variables=workflow.environment_variables, + rag_pipeline_variables=rag_pipeline_variables, + ), ) + root_node_id = self.application_generate_entity.start_node_id or get_default_root_node_id( + workflow.graph_dict + ) + add_node_inputs_to_pool(variable_pool, node_id=root_node_id, inputs=inputs) graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) # init graph graph = self._init_rag_pipeline_graph( graph_runtime_state=graph_runtime_state, - start_node_id=self.application_generate_entity.start_node_id, + start_node_id=root_node_id, workflow=workflow, user_from=user_from, invoke_from=invoke_from, diff --git a/api/core/app/apps/workflow/app_generator.py b/api/core/app/apps/workflow/app_generator.py index 6fbe19a3b2..9618ab35c6 100644 --- a/api/core/app/apps/workflow/app_generator.py +++ b/api/core/app/apps/workflow/app_generator.py @@ -8,6 +8,10 @@ from collections.abc import Generator, Mapping, Sequence from 
typing import TYPE_CHECKING, Any, Literal, Union, overload from flask import Flask, current_app +from graphon.graph_engine.layers import GraphEngineLayer +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError +from graphon.runtime import GraphRuntimeState +from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader from pydantic import ValidationError from sqlalchemy import select from sqlalchemy.orm import sessionmaker @@ -17,6 +21,7 @@ from configs import dify_config from core.app.app_config.features.file_upload.manager import FileUploadConfigManager from core.app.apps.base_app_generator import BaseAppGenerator from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom +from core.app.apps.draft_variable_saver import DraftVariableSaverFactory from core.app.apps.exc import GenerateTaskStoppedError from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager from core.app.apps.workflow.app_queue_manager import WorkflowAppQueueManager @@ -30,13 +35,7 @@ from core.db.session_factory import session_factory from core.helper.trace_id_helper import extract_external_trace_id_from_args from core.ops.ops_trace_manager import TraceQueueManager from core.repositories import DifyCoreRepositoryFactory -from dify_graph.graph_engine.layers.base import GraphEngineLayer -from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError -from dify_graph.repositories.draft_variable_repository import DraftVariableSaverFactory -from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository -from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository -from dify_graph.runtime import GraphRuntimeState -from dify_graph.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader +from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository from extensions.ext_database import db from factories import 
file_factory from libs.flask_utils import preserve_flask_contexts @@ -129,107 +128,109 @@ class WorkflowAppGenerator(BaseAppGenerator): graph_engine_layers: Sequence[GraphEngineLayer] = (), pause_state_config: PauseStateLayerConfig | None = None, ) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]: - files: Sequence[Mapping[str, Any]] = args.get("files") or [] + with self._bind_file_access_scope(tenant_id=app_model.tenant_id, user=user, invoke_from=invoke_from): + files: Sequence[Mapping[str, Any]] = args.get("files") or [] - # parse files - # TODO(QuantumGhost): Move file parsing logic to the API controller layer - # for better separation of concerns. - # - # For implementation reference, see the `_parse_file` function and - # `DraftWorkflowNodeRunApi` class which handle this properly. - file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False) - system_files = file_factory.build_from_mappings( - mappings=files, - tenant_id=app_model.tenant_id, - config=file_extra_config, - strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False, - ) - - # convert to app config - app_config = WorkflowAppConfigManager.get_app_config( - app_model=app_model, - workflow=workflow, - ) - - # get tracing instance - trace_manager = TraceQueueManager( - app_id=app_model.id, - user_id=user.id if isinstance(user, Account) else user.session_id, - ) - - inputs: Mapping[str, Any] = args["inputs"] - - extras = { - **extract_external_trace_id_from_args(args), - } - workflow_run_id = str(workflow_run_id or uuid.uuid4()) - # FIXME (Yeuoly): we need to remove the SKIP_PREPARE_USER_INPUTS_KEY from the args - # trigger shouldn't prepare user inputs - if self._should_prepare_user_inputs(args): - inputs = self._prepare_user_inputs( - user_inputs=inputs, - variables=app_config.variables, + # parse files + # TODO(QuantumGhost): Move file parsing logic to the API controller layer + # for better separation of concerns. 
+ # + # For implementation reference, see the `_parse_file` function and + # `DraftWorkflowNodeRunApi` class which handle this properly. + file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False) + system_files = file_factory.build_from_mappings( + mappings=files, tenant_id=app_model.tenant_id, + config=file_extra_config, strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False, + access_controller=self._file_access_controller, ) - # init application generate entity - application_generate_entity = WorkflowAppGenerateEntity( - task_id=str(uuid.uuid4()), - app_config=app_config, - file_upload_config=file_extra_config, - inputs=inputs, - files=list(system_files), - user_id=user.id, - stream=streaming, - invoke_from=invoke_from, - call_depth=call_depth, - trace_manager=trace_manager, - workflow_execution_id=workflow_run_id, - extras=extras, - ) - contexts.plugin_tool_providers.set({}) - contexts.plugin_tool_providers_lock.set(threading.Lock()) + # convert to app config + app_config = WorkflowAppConfigManager.get_app_config( + app_model=app_model, + workflow=workflow, + ) - # Create repositories - # - # Create session factory - session_factory = sessionmaker(bind=db.engine, expire_on_commit=False) - # Create workflow execution(aka workflow run) repository - if triggered_from is not None: - # Use explicitly provided triggered_from (for async triggers) - workflow_triggered_from = triggered_from - elif invoke_from == InvokeFrom.DEBUGGER: - workflow_triggered_from = WorkflowRunTriggeredFrom.DEBUGGING - else: - workflow_triggered_from = WorkflowRunTriggeredFrom.APP_RUN - workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository( - session_factory=session_factory, - user=user, - app_id=application_generate_entity.app_config.app_id, - triggered_from=workflow_triggered_from, - ) - # Create workflow node execution repository - workflow_node_execution_repository = 
DifyCoreRepositoryFactory.create_workflow_node_execution_repository( - session_factory=session_factory, - user=user, - app_id=application_generate_entity.app_config.app_id, - triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, - ) + # get tracing instance + trace_manager = TraceQueueManager( + app_id=app_model.id, + user_id=user.id if isinstance(user, Account) else user.session_id, + ) - return self._generate( - app_model=app_model, - workflow=workflow, - user=user, - application_generate_entity=application_generate_entity, - invoke_from=invoke_from, - workflow_execution_repository=workflow_execution_repository, - workflow_node_execution_repository=workflow_node_execution_repository, - streaming=streaming, - root_node_id=root_node_id, - graph_engine_layers=graph_engine_layers, - pause_state_config=pause_state_config, - ) + inputs: Mapping[str, Any] = args["inputs"] + + extras = { + **extract_external_trace_id_from_args(args), + } + workflow_run_id = str(workflow_run_id or uuid.uuid4()) + # FIXME (Yeuoly): we need to remove the SKIP_PREPARE_USER_INPUTS_KEY from the args + # trigger shouldn't prepare user inputs + if self._should_prepare_user_inputs(args): + inputs = self._prepare_user_inputs( + user_inputs=inputs, + variables=app_config.variables, + tenant_id=app_model.tenant_id, + strict_type_validation=True if invoke_from == InvokeFrom.SERVICE_API else False, + ) + # init application generate entity + application_generate_entity = WorkflowAppGenerateEntity( + task_id=str(uuid.uuid4()), + app_config=app_config, + file_upload_config=file_extra_config, + inputs=inputs, + files=list(system_files), + user_id=user.id, + stream=streaming, + invoke_from=invoke_from, + call_depth=call_depth, + trace_manager=trace_manager, + workflow_execution_id=workflow_run_id, + extras=extras, + ) + + contexts.plugin_tool_providers.set({}) + contexts.plugin_tool_providers_lock.set(threading.Lock()) + + # Create repositories + # + # Create session factory + session_factory = 
sessionmaker(bind=db.engine, expire_on_commit=False) + # Create workflow execution(aka workflow run) repository + if triggered_from is not None: + # Use explicitly provided triggered_from (for async triggers) + workflow_triggered_from = triggered_from + elif invoke_from == InvokeFrom.DEBUGGER: + workflow_triggered_from = WorkflowRunTriggeredFrom.DEBUGGING + else: + workflow_triggered_from = WorkflowRunTriggeredFrom.APP_RUN + workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository( + session_factory=session_factory, + user=user, + app_id=application_generate_entity.app_config.app_id, + triggered_from=workflow_triggered_from, + ) + # Create workflow node execution repository + workflow_node_execution_repository = DifyCoreRepositoryFactory.create_workflow_node_execution_repository( + session_factory=session_factory, + user=user, + app_id=application_generate_entity.app_config.app_id, + triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, + ) + + return self._generate( + app_model=app_model, + workflow=workflow, + user=user, + application_generate_entity=application_generate_entity, + invoke_from=invoke_from, + workflow_execution_repository=workflow_execution_repository, + workflow_node_execution_repository=workflow_node_execution_repository, + streaming=streaming, + root_node_id=root_node_id, + graph_engine_layers=graph_engine_layers, + pause_state_config=pause_state_config, + ) def resume( self, @@ -292,62 +293,67 @@ class WorkflowAppGenerator(BaseAppGenerator): :param workflow_node_execution_repository: repository for workflow node execution :param streaming: is stream """ - graph_layers: list[GraphEngineLayer] = list(graph_engine_layers) + with self._bind_file_access_scope( + tenant_id=application_generate_entity.app_config.tenant_id, + user=user, + invoke_from=invoke_from, + ): + graph_layers: list[GraphEngineLayer] = list(graph_engine_layers) - # init queue manager - queue_manager = WorkflowAppQueueManager( - 
task_id=application_generate_entity.task_id, - user_id=application_generate_entity.user_id, - invoke_from=application_generate_entity.invoke_from, - app_mode=app_model.mode, - ) - - if pause_state_config is not None: - graph_layers.append( - PauseStatePersistenceLayer( - session_factory=pause_state_config.session_factory, - generate_entity=application_generate_entity, - state_owner_user_id=pause_state_config.state_owner_user_id, - ) + # init queue manager + queue_manager = WorkflowAppQueueManager( + task_id=application_generate_entity.task_id, + user_id=application_generate_entity.user_id, + invoke_from=application_generate_entity.invoke_from, + app_mode=app_model.mode, ) - # new thread with request context and contextvars - context = contextvars.copy_context() + if pause_state_config is not None: + graph_layers.append( + PauseStatePersistenceLayer( + session_factory=pause_state_config.session_factory, + generate_entity=application_generate_entity, + state_owner_user_id=pause_state_config.state_owner_user_id, + ) + ) - # release database connection, because the following new thread operations may take a long time - db.session.close() + # new thread with request context and contextvars + context = contextvars.copy_context() - worker_thread = threading.Thread( - target=self._generate_worker, - kwargs={ - "flask_app": current_app._get_current_object(), # type: ignore - "application_generate_entity": application_generate_entity, - "queue_manager": queue_manager, - "context": context, - "variable_loader": variable_loader, - "root_node_id": root_node_id, - "workflow_execution_repository": workflow_execution_repository, - "workflow_node_execution_repository": workflow_node_execution_repository, - "graph_engine_layers": tuple(graph_layers), - "graph_runtime_state": graph_runtime_state, - }, - ) + # release database connection, because the following new thread operations may take a long time + db.session.close() - worker_thread.start() + worker_thread = threading.Thread( + 
target=self._generate_worker, + kwargs={ + "flask_app": current_app._get_current_object(), # type: ignore + "application_generate_entity": application_generate_entity, + "queue_manager": queue_manager, + "context": context, + "variable_loader": variable_loader, + "root_node_id": root_node_id, + "workflow_execution_repository": workflow_execution_repository, + "workflow_node_execution_repository": workflow_node_execution_repository, + "graph_engine_layers": tuple(graph_layers), + "graph_runtime_state": graph_runtime_state, + }, + ) - draft_var_saver_factory = self._get_draft_var_saver_factory(invoke_from, user) + worker_thread.start() - # return response or stream generator - response = self._handle_response( - application_generate_entity=application_generate_entity, - workflow=workflow, - queue_manager=queue_manager, - user=user, - draft_var_saver_factory=draft_var_saver_factory, - stream=streaming, - ) + draft_var_saver_factory = self._get_draft_var_saver_factory(invoke_from, user) - return WorkflowAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) + # return response or stream generator + response = self._handle_response( + application_generate_entity=application_generate_entity, + workflow=workflow, + queue_manager=queue_manager, + user=user, + draft_var_saver_factory=draft_var_saver_factory, + stream=streaming, + ) + + return WorkflowAppGenerateResponseConverter.convert(response=response, invoke_from=invoke_from) def single_iteration_generate( self, diff --git a/api/core/app/apps/workflow/app_runner.py b/api/core/app/apps/workflow/app_runner.py index caea8b6b95..2cb8088971 100644 --- a/api/core/app/apps/workflow/app_runner.py +++ b/api/core/app/apps/workflow/app_runner.py @@ -3,20 +3,22 @@ import time from collections.abc import Sequence from typing import cast +from graphon.enums import WorkflowType +from graphon.graph_engine.command_channels import RedisChannel +from graphon.graph_engine.layers import GraphEngineLayer +from 
graphon.runtime import GraphRuntimeState, VariablePool +from graphon.variable_loader import VariableLoader + from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.workflow.app_config_manager import WorkflowAppConfig from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer +from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository +from core.workflow.node_factory import get_default_root_node_id +from core.workflow.system_variables import build_bootstrap_variables, build_system_variables +from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool from core.workflow.workflow_entry import WorkflowEntry -from dify_graph.enums import WorkflowType -from dify_graph.graph_engine.command_channels.redis_channel import RedisChannel -from dify_graph.graph_engine.layers.base import GraphEngineLayer -from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository -from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable -from dify_graph.variable_loader import VariableLoader from extensions.ext_redis import redis_client from extensions.otel import WorkflowAppRunnerHandler, trace_span from libs.datetime_utils import naive_utc_now @@ -91,12 +93,13 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): workflow=self._workflow, single_iteration_run=self.application_generate_entity.single_iteration_run, single_loop_run=self.application_generate_entity.single_loop_run, + user_id=self.application_generate_entity.user_id, ) else: inputs = self.application_generate_entity.inputs # Create a 
variable pool. - system_inputs = SystemVariable( + system_inputs = build_system_variables( files=self.application_generate_entity.files, user_id=self._sys_user_id, app_id=app_config.app_id, @@ -104,12 +107,16 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): workflow_id=app_config.workflow_id, workflow_execution_id=self.application_generate_entity.workflow_execution_id, ) - variable_pool = VariablePool( - system_variables=system_inputs, - user_inputs=inputs, - environment_variables=self._workflow.environment_variables, - conversation_variables=[], + variable_pool = VariablePool() + add_variables_to_pool( + variable_pool, + build_bootstrap_variables( + system_variables=system_inputs, + environment_variables=self._workflow.environment_variables, + ), ) + root_node_id = self._root_node_id or get_default_root_node_id(self._workflow.graph_dict) + add_node_inputs_to_pool(variable_pool, node_id=root_node_id, inputs=inputs) graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) graph = self._init_graph( @@ -120,7 +127,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner): user_id=self.application_generate_entity.user_id, user_from=user_from, invoke_from=invoke_from, - root_node_id=self._root_node_id, + root_node_id=root_node_id, ) # RUN WORKFLOW diff --git a/api/core/app/apps/workflow/generate_task_pipeline.py b/api/core/app/apps/workflow/generate_task_pipeline.py index bd6e2a0302..49af169e88 100644 --- a/api/core/app/apps/workflow/generate_task_pipeline.py +++ b/api/core/app/apps/workflow/generate_task_pipeline.py @@ -4,12 +4,16 @@ from collections.abc import Callable, Generator from contextlib import contextmanager from typing import Union +from graphon.entities import WorkflowStartReason +from graphon.enums import WorkflowExecutionStatus +from graphon.runtime import GraphRuntimeState from sqlalchemy.orm import Session from constants.tts_auto_play_timeout import TTS_AUTO_PLAY_TIMEOUT, TTS_AUTO_PLAY_YIELD_CPU_TIME from 
core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.common.graph_runtime_state_support import GraphRuntimeStateSupport from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter +from core.app.apps.draft_variable_saver import DraftVariableSaverFactory from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity from core.app.entities.queue_entities import ( AppQueueEvent, @@ -55,11 +59,7 @@ from core.app.entities.task_entities import ( from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk from core.ops.ops_trace_manager import TraceQueueManager -from dify_graph.entities.workflow_start_reason import WorkflowStartReason -from dify_graph.enums import WorkflowExecutionStatus -from dify_graph.repositories.draft_variable_repository import DraftVariableSaverFactory -from dify_graph.runtime import GraphRuntimeState -from dify_graph.system_variable import SystemVariable +from core.workflow.system_variables import build_system_variables from extensions.ext_database import db from models import Account from models.enums import CreatorUserRole @@ -104,7 +104,7 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport): self._invoke_from = queue_manager.invoke_from self._draft_var_saver_factory = draft_var_saver_factory self._workflow = workflow - self._workflow_system_variables = SystemVariable( + self._workflow_system_variables = build_system_variables( files=application_generate_entity.files, user_id=user_session_id, app_id=application_generate_entity.app_config.app_id, @@ -728,13 +728,11 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport): return response def _save_output_for_event(self, event: QueueNodeSucceededEvent | QueueNodeExceptionEvent, node_execution_id: str): - with Session(db.engine) as session, session.begin(): - saver = self._draft_var_saver_factory( - 
session=session, - app_id=self._application_generate_entity.app_config.app_id, - node_id=event.node_id, - node_type=event.node_type, - node_execution_id=node_execution_id, - enclosing_node_id=event.in_loop_id or event.in_iteration_id, - ) - saver.save(event.process_data, event.outputs) + saver = self._draft_var_saver_factory( + app_id=self._application_generate_entity.app_config.app_id, + node_id=event.node_id, + node_type=event.node_type, + node_execution_id=node_execution_id, + enclosing_node_id=event.in_loop_id or event.in_iteration_id, + ) + saver.save(event.process_data, event.outputs) diff --git a/api/core/app/apps/workflow_app_runner.py b/api/core/app/apps/workflow_app_runner.py index adc6cce9af..f68c8e60b4 100644 --- a/api/core/app/apps/workflow_app_runner.py +++ b/api/core/app/apps/workflow_app_runner.py @@ -3,6 +3,40 @@ import time from collections.abc import Mapping, Sequence from typing import Any, cast +from graphon.entities import GraphInitParams +from graphon.entities.graph_config import NodeConfigDictAdapter +from graphon.entities.pause_reason import HumanInputRequired +from graphon.graph import Graph +from graphon.graph_engine.layers import GraphEngineLayer +from graphon.graph_events import ( + GraphEngineEvent, + GraphRunAbortedEvent, + GraphRunFailedEvent, + GraphRunPartialSucceededEvent, + GraphRunPausedEvent, + GraphRunStartedEvent, + GraphRunSucceededEvent, + NodeRunAgentLogEvent, + NodeRunExceptionEvent, + NodeRunFailedEvent, + NodeRunHumanInputFormFilledEvent, + NodeRunHumanInputFormTimeoutEvent, + NodeRunIterationFailedEvent, + NodeRunIterationNextEvent, + NodeRunIterationStartedEvent, + NodeRunIterationSucceededEvent, + NodeRunLoopFailedEvent, + NodeRunLoopNextEvent, + NodeRunLoopStartedEvent, + NodeRunLoopSucceededEvent, + NodeRunRetrieverResourceEvent, + NodeRunRetryEvent, + NodeRunStartedEvent, + NodeRunStreamChunkEvent, + NodeRunSucceededEvent, +) +from graphon.runtime import GraphRuntimeState, VariablePool +from 
graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool from pydantic import ValidationError from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom @@ -34,42 +68,16 @@ from core.app.entities.queue_entities import ( ) from core.rag.entities.citation_metadata import RetrievalSourceMetadata from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id, resolve_workflow_node_class -from core.workflow.workflow_entry import WorkflowEntry -from dify_graph.entities import GraphInitParams -from dify_graph.entities.graph_config import NodeConfigDictAdapter -from dify_graph.entities.pause_reason import HumanInputRequired -from dify_graph.graph import Graph -from dify_graph.graph_engine.layers.base import GraphEngineLayer -from dify_graph.graph_events import ( - GraphEngineEvent, - GraphRunFailedEvent, - GraphRunPartialSucceededEvent, - GraphRunPausedEvent, - GraphRunStartedEvent, - GraphRunSucceededEvent, - NodeRunAgentLogEvent, - NodeRunExceptionEvent, - NodeRunFailedEvent, - NodeRunHumanInputFormFilledEvent, - NodeRunHumanInputFormTimeoutEvent, - NodeRunIterationFailedEvent, - NodeRunIterationNextEvent, - NodeRunIterationStartedEvent, - NodeRunIterationSucceededEvent, - NodeRunLoopFailedEvent, - NodeRunLoopNextEvent, - NodeRunLoopStartedEvent, - NodeRunLoopSucceededEvent, - NodeRunRetrieverResourceEvent, - NodeRunRetryEvent, - NodeRunStartedEvent, - NodeRunStreamChunkEvent, - NodeRunSucceededEvent, +from core.workflow.system_variables import ( + build_bootstrap_variables, + default_system_variables, + get_node_creation_preload_selectors, + inject_default_system_variable_mappings, + preload_node_creation_variables, ) -from dify_graph.graph_events.graph import GraphRunAbortedEvent -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable -from dify_graph.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, 
load_into_variable_pool +from core.workflow.variable_pool_initializer import add_variables_to_pool +from core.workflow.workflow_entry import WorkflowEntry +from core.workflow.workflow_run_outputs import project_node_outputs_for_workflow_run from models.workflow import Workflow from tasks.mail_human_input_delivery_task import dispatch_human_input_email_task @@ -156,6 +164,8 @@ class WorkflowBasedAppRunner: workflow: Workflow, single_iteration_run: Any | None = None, single_loop_run: Any | None = None, + *, + user_id: str, ) -> tuple[Graph, VariablePool, GraphRuntimeState]: """ Prepare graph, variable pool, and runtime state for single node execution @@ -173,14 +183,15 @@ class WorkflowBasedAppRunner: ValueError: If neither single_iteration_run nor single_loop_run is specified """ # Create initial runtime state with variable pool containing environment variables - graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool( - system_variables=SystemVariable.default(), - user_inputs={}, + variable_pool = VariablePool() + add_variables_to_pool( + variable_pool, + build_bootstrap_variables( + system_variables=default_system_variables(), environment_variables=workflow.environment_variables, ), - start_at=time.time(), ) + graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.time()) # Determine which type of single node execution and get graph/variable_pool if single_iteration_run: @@ -191,6 +202,7 @@ class WorkflowBasedAppRunner: graph_runtime_state=graph_runtime_state, node_type_filter_key="iteration_id", node_type_label="iteration", + user_id=user_id, ) elif single_loop_run: graph, variable_pool = self._get_graph_and_variable_pool_for_single_node_run( @@ -200,6 +212,7 @@ class WorkflowBasedAppRunner: graph_runtime_state=graph_runtime_state, node_type_filter_key="loop_id", node_type_label="loop", + user_id=user_id, ) else: raise ValueError("Neither single_iteration_run nor single_loop_run is specified") @@ -216,6 +229,8 @@ class 
WorkflowBasedAppRunner: graph_runtime_state: GraphRuntimeState, node_type_filter_key: str, # 'iteration_id' or 'loop_id' node_type_label: str = "node", # 'iteration' or 'loop' for error messages + *, + user_id: str = "", ) -> tuple[Graph, VariablePool]: """ Get graph and variable pool for single node execution (iteration or loop). @@ -272,6 +287,8 @@ class WorkflowBasedAppRunner: graph_config["edges"] = edge_configs + typed_node_configs = [NodeConfigDictAdapter.validate_python(node) for node in node_configs] + # Create required parameters for Graph.init graph_init_params = GraphInitParams( workflow_id=workflow.id, @@ -279,7 +296,7 @@ class WorkflowBasedAppRunner: run_context=build_dify_run_context( tenant_id=workflow.tenant_id, app_id=self._app_id, - user_id="", + user_id=user_id, user_from=UserFrom.ACCOUNT, invoke_from=InvokeFrom.DEBUGGER, ), @@ -291,26 +308,15 @@ class WorkflowBasedAppRunner: graph_runtime_state=graph_runtime_state, ) - # init graph - graph = Graph.init( - graph_config=graph_config, node_factory=node_factory, root_node_id=node_id, skip_validation=True - ) - - if not graph: - raise ValueError("graph not found in workflow") - - # fetch node config from node id target_node_config = None - for node in node_configs: - if node.get("id") == node_id: + for node in typed_node_configs: + if node["id"] == node_id: target_node_config = node break if not target_node_config: raise ValueError(f"{node_type_label} node id not found in workflow graph") - target_node_config = NodeConfigDictAdapter.validate_python(target_node_config) - # Get node class node_type = target_node_config["data"].type node_version = str(target_node_config["data"].version) @@ -319,12 +325,31 @@ class WorkflowBasedAppRunner: # Use the variable pool from graph_runtime_state instead of creating a new one variable_pool = graph_runtime_state.variable_pool + preload_node_creation_variables( + variable_loader=self._variable_loader, + variable_pool=variable_pool, + selectors=[ + selector + for 
node_config in typed_node_configs + for selector in get_node_creation_preload_selectors( + node_type=node_config["data"].type, + node_data=node_config["data"], + ) + ], + ) + try: variable_mapping = node_cls.extract_variable_selector_to_variable_mapping( graph_config=workflow.graph_dict, config=target_node_config ) except NotImplementedError: variable_mapping = {} + variable_mapping = inject_default_system_variable_mappings( + node_id=target_node_config["id"], + node_type=node_type, + node_data=target_node_config["data"], + variable_mapping=variable_mapping, + ) load_into_variable_pool( variable_loader=self._variable_loader, @@ -340,6 +365,14 @@ class WorkflowBasedAppRunner: tenant_id=workflow.tenant_id, ) + # init graph after constructor-time context has been loaded + graph = Graph.init( + graph_config=graph_config, node_factory=node_factory, root_node_id=node_id, skip_validation=True + ) + + if not graph: + raise ValueError("graph not found in workflow") + return graph, variable_pool @staticmethod @@ -408,7 +441,11 @@ class WorkflowBasedAppRunner: node_run_result = event.node_run_result inputs = node_run_result.inputs process_data = node_run_result.process_data - outputs = node_run_result.outputs + outputs = project_node_outputs_for_workflow_run( + node_type=event.node_type, + inputs=inputs, + outputs=node_run_result.outputs, + ) execution_metadata = node_run_result.metadata self._publish_event( QueueNodeRetryEvent( @@ -448,7 +485,11 @@ class WorkflowBasedAppRunner: node_run_result = event.node_run_result inputs = node_run_result.inputs process_data = node_run_result.process_data - outputs = node_run_result.outputs + outputs = project_node_outputs_for_workflow_run( + node_type=event.node_type, + inputs=inputs, + outputs=node_run_result.outputs, + ) execution_metadata = node_run_result.metadata self._publish_event( QueueNodeSucceededEvent( @@ -466,6 +507,11 @@ class WorkflowBasedAppRunner: ) ) elif isinstance(event, NodeRunFailedEvent): + outputs = 
project_node_outputs_for_workflow_run( + node_type=event.node_type, + inputs=event.node_run_result.inputs, + outputs=event.node_run_result.outputs, + ) self._publish_event( QueueNodeFailedEvent( node_execution_id=event.id, @@ -475,7 +521,7 @@ class WorkflowBasedAppRunner: finished_at=event.finished_at, inputs=event.node_run_result.inputs, process_data=event.node_run_result.process_data, - outputs=event.node_run_result.outputs, + outputs=outputs, error=event.node_run_result.error or "Unknown error", execution_metadata=event.node_run_result.metadata, in_iteration_id=event.in_iteration_id, @@ -483,6 +529,11 @@ class WorkflowBasedAppRunner: ) ) elif isinstance(event, NodeRunExceptionEvent): + outputs = project_node_outputs_for_workflow_run( + node_type=event.node_type, + inputs=event.node_run_result.inputs, + outputs=event.node_run_result.outputs, + ) self._publish_event( QueueNodeExceptionEvent( node_execution_id=event.id, @@ -492,7 +543,7 @@ class WorkflowBasedAppRunner: finished_at=event.finished_at, inputs=event.node_run_result.inputs, process_data=event.node_run_result.process_data, - outputs=event.node_run_result.outputs, + outputs=outputs, error=event.node_run_result.error or "Unknown error", execution_metadata=event.node_run_result.metadata, in_iteration_id=event.in_iteration_id, diff --git a/api/core/app/entities/app_invoke_entities.py b/api/core/app/entities/app_invoke_entities.py index ecbb1cf2f3..0cdbb5f50a 100644 --- a/api/core/app/entities/app_invoke_entities.py +++ b/api/core/app/entities/app_invoke_entities.py @@ -2,19 +2,21 @@ from collections.abc import Mapping, Sequence from enum import StrEnum from typing import TYPE_CHECKING, Any, Optional +from graphon.file import File, FileUploadConfig +from graphon.model_runtime.entities.model_entities import AIModelEntity from pydantic import BaseModel, ConfigDict, Field, ValidationInfo, field_validator from constants import UUID_NIL from core.app.app_config.entities import EasyUIBasedAppConfig, 
WorkflowUIBasedAppConfig from core.entities.provider_configuration import ProviderModelBundle -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY -from dify_graph.file import File, FileUploadConfig -from dify_graph.model_runtime.entities.model_entities import AIModelEntity if TYPE_CHECKING: from core.ops.ops_trace_manager import TraceQueueManager +DIFY_RUN_CONTEXT_KEY = "_dify" + + class UserFrom(StrEnum): ACCOUNT = "account" END_USER = "end-user" diff --git a/api/core/app/entities/queue_entities.py b/api/core/app/entities/queue_entities.py index d2a36f2a0d..5e56341f89 100644 --- a/api/core/app/entities/queue_entities.py +++ b/api/core/app/entities/queue_entities.py @@ -3,14 +3,14 @@ from datetime import datetime from enum import StrEnum, auto from typing import Any +from graphon.entities import WorkflowStartReason +from graphon.entities.pause_reason import PauseReason +from graphon.enums import NodeType, WorkflowNodeExecutionMetadataKey +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk from pydantic import BaseModel, ConfigDict, Field from core.app.entities.agent_strategy import AgentStrategyInfo from core.rag.entities.citation_metadata import RetrievalSourceMetadata -from dify_graph.entities.pause_reason import PauseReason -from dify_graph.entities.workflow_start_reason import WorkflowStartReason -from dify_graph.enums import NodeType, WorkflowNodeExecutionMetadataKey -from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk class QueueEvent(StrEnum): diff --git a/api/core/app/entities/task_entities.py b/api/core/app/entities/task_entities.py index 46a8ab52f2..ba3b2e356f 100644 --- a/api/core/app/entities/task_entities.py +++ b/api/core/app/entities/task_entities.py @@ -2,14 +2,14 @@ from collections.abc import Mapping, Sequence from enum import StrEnum from typing import Any +from graphon.entities import WorkflowStartReason +from graphon.enums import WorkflowExecutionStatus, 
WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage +from graphon.nodes.human_input.entities import FormInput, UserAction from pydantic import BaseModel, ConfigDict, Field from core.app.entities.agent_strategy import AgentStrategyInfo from core.rag.entities.citation_metadata import RetrievalSourceMetadata -from dify_graph.entities.workflow_start_reason import WorkflowStartReason -from dify_graph.enums import WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus -from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMUsage -from dify_graph.nodes.human_input.entities import FormInput, UserAction class AnnotationReplyAccount(BaseModel): diff --git a/api/core/app/features/hosting_moderation/hosting_moderation.py b/api/core/app/features/hosting_moderation/hosting_moderation.py index 5ed1fadc41..d2d2fea4fb 100644 --- a/api/core/app/features/hosting_moderation/hosting_moderation.py +++ b/api/core/app/features/hosting_moderation/hosting_moderation.py @@ -1,8 +1,9 @@ import logging +from graphon.model_runtime.entities.message_entities import PromptMessage + from core.app.entities.app_invoke_entities import EasyUIBasedAppGenerateEntity from core.helper import moderation -from dify_graph.model_runtime.entities.message_entities import PromptMessage logger = logging.getLogger(__name__) diff --git a/api/core/app/file_access/__init__.py b/api/core/app/file_access/__init__.py new file mode 100644 index 0000000000..a75ab9781b --- /dev/null +++ b/api/core/app/file_access/__init__.py @@ -0,0 +1,11 @@ +from .controller import DatabaseFileAccessController +from .protocols import FileAccessControllerProtocol +from .scope import FileAccessScope, bind_file_access_scope, get_current_file_access_scope + +__all__ = [ + "DatabaseFileAccessController", + "FileAccessControllerProtocol", + "FileAccessScope", + "bind_file_access_scope", + 
"get_current_file_access_scope", +] diff --git a/api/core/app/file_access/controller.py b/api/core/app/file_access/controller.py new file mode 100644 index 0000000000..300c187083 --- /dev/null +++ b/api/core/app/file_access/controller.py @@ -0,0 +1,103 @@ +from __future__ import annotations + +from collections.abc import Callable + +from sqlalchemy import select +from sqlalchemy.orm import Session +from sqlalchemy.sql import Select + +from models import ToolFile, UploadFile +from models.enums import CreatorUserRole + +from .protocols import FileAccessControllerProtocol +from .scope import FileAccessScope, get_current_file_access_scope + + +class DatabaseFileAccessController(FileAccessControllerProtocol): + """Workflow-layer authorization helper for database-backed file lookups. + + Tenant scoping remains mandatory. When the current execution belongs to an + end user, the lookup is additionally constrained to that end user's file + ownership markers. + """ + + _scope_getter: Callable[[], FileAccessScope | None] + + def __init__( + self, + *, + scope_getter: Callable[[], FileAccessScope | None] = get_current_file_access_scope, + ) -> None: + self._scope_getter = scope_getter + + def current_scope(self) -> FileAccessScope | None: + return self._scope_getter() + + def apply_upload_file_filters( + self, + stmt: Select[tuple[UploadFile]], + *, + scope: FileAccessScope | None = None, + ) -> Select[tuple[UploadFile]]: + resolved_scope = scope or self.current_scope() + if resolved_scope is None: + return stmt + + scoped_stmt = stmt.where(UploadFile.tenant_id == resolved_scope.tenant_id) + if not resolved_scope.requires_user_ownership: + return scoped_stmt + + return scoped_stmt.where( + UploadFile.created_by_role == CreatorUserRole.END_USER, + UploadFile.created_by == resolved_scope.user_id, + ) + + def apply_tool_file_filters( + self, + stmt: Select[tuple[ToolFile]], + *, + scope: FileAccessScope | None = None, + ) -> Select[tuple[ToolFile]]: + resolved_scope = scope or 
self.current_scope() + if resolved_scope is None: + return stmt + + scoped_stmt = stmt.where(ToolFile.tenant_id == resolved_scope.tenant_id) + if not resolved_scope.requires_user_ownership: + return scoped_stmt + + return scoped_stmt.where(ToolFile.user_id == resolved_scope.user_id) + + def get_upload_file( + self, + *, + session: Session, + file_id: str, + scope: FileAccessScope | None = None, + ) -> UploadFile | None: + resolved_scope = scope or self.current_scope() + if resolved_scope is None: + return session.get(UploadFile, file_id) + + stmt = self.apply_upload_file_filters( + select(UploadFile).where(UploadFile.id == file_id), + scope=resolved_scope, + ) + return session.scalar(stmt) + + def get_tool_file( + self, + *, + session: Session, + file_id: str, + scope: FileAccessScope | None = None, + ) -> ToolFile | None: + resolved_scope = scope or self.current_scope() + if resolved_scope is None: + return session.get(ToolFile, file_id) + + stmt = self.apply_tool_file_filters( + select(ToolFile).where(ToolFile.id == file_id), + scope=resolved_scope, + ) + return session.scalar(stmt) diff --git a/api/core/app/file_access/protocols.py b/api/core/app/file_access/protocols.py new file mode 100644 index 0000000000..8bb3eb9924 --- /dev/null +++ b/api/core/app/file_access/protocols.py @@ -0,0 +1,81 @@ +from __future__ import annotations + +from typing import Protocol + +from sqlalchemy.orm import Session +from sqlalchemy.sql import Select + +from models import ToolFile, UploadFile + +from .scope import FileAccessScope + + +class FileAccessControllerProtocol(Protocol): + """Contract for applying access rules to file lookups. + + Implementations translate an optional execution scope into query constraints + and authorized record retrieval. The contract is intentionally limited to + ownership and tenancy rules for workflow-layer file access. + """ + + def current_scope(self) -> FileAccessScope | None: + """Return the scope active for the current execution, if one exists. 
+ + Callers use this to decide whether embedded file metadata may be trusted + or whether a fresh authorized lookup is required. + """ + ... + + def apply_upload_file_filters( + self, + stmt: Select[tuple[UploadFile]], + *, + scope: FileAccessScope | None = None, + ) -> Select[tuple[UploadFile]]: + """Return an upload-file query constrained by the supplied access scope. + + The returned statement must preserve the caller's existing predicates and + append only access-control conditions. + """ + ... + + def apply_tool_file_filters( + self, + stmt: Select[tuple[ToolFile]], + *, + scope: FileAccessScope | None = None, + ) -> Select[tuple[ToolFile]]: + """Return a tool-file query constrained by the supplied access scope. + + The returned statement must preserve the caller's existing predicates and + append only access-control conditions. + """ + ... + + def get_upload_file( + self, + *, + session: Session, + file_id: str, + scope: FileAccessScope | None = None, + ) -> UploadFile | None: + """Load one authorized upload-file record for the given identifier. + + Returns ``None`` when the file does not exist or when the scope does not + permit access to that record. + """ + ... + + def get_tool_file( + self, + *, + session: Session, + file_id: str, + scope: FileAccessScope | None = None, + ) -> ToolFile | None: + """Load one authorized tool-file record for the given identifier. + + Returns ``None`` when the file does not exist or when the scope does not + permit access to that record. + """ + ... 
diff --git a/api/core/app/file_access/scope.py b/api/core/app/file_access/scope.py new file mode 100644 index 0000000000..80d504ef1c --- /dev/null +++ b/api/core/app/file_access/scope.py @@ -0,0 +1,40 @@ +from __future__ import annotations + +from collections.abc import Iterator +from contextlib import contextmanager +from contextvars import ContextVar +from dataclasses import dataclass + +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom + +_current_file_access_scope: ContextVar[FileAccessScope | None] = ContextVar( + "current_file_access_scope", + default=None, +) + + +@dataclass(frozen=True, slots=True) +class FileAccessScope: + """Request-scoped ownership context used by workflow-layer file lookups.""" + + tenant_id: str + user_id: str + user_from: UserFrom + invoke_from: InvokeFrom + + @property + def requires_user_ownership(self) -> bool: + return self.user_from == UserFrom.END_USER + + +def get_current_file_access_scope() -> FileAccessScope | None: + return _current_file_access_scope.get() + + +@contextmanager +def bind_file_access_scope(scope: FileAccessScope) -> Iterator[None]: + token = _current_file_access_scope.set(scope) + try: + yield + finally: + _current_file_access_scope.reset(token) diff --git a/api/core/app/layers/conversation_variable_persist_layer.py b/api/core/app/layers/conversation_variable_persist_layer.py index d227e4e904..e09869f5f8 100644 --- a/api/core/app/layers/conversation_variable_persist_layer.py +++ b/api/core/app/layers/conversation_variable_persist_layer.py @@ -1,12 +1,20 @@ +""" +Persist conversation-scoped variable updates emitted by the graph engine. + +The graph package emits generic variable update events and stays unaware of +conversation identity or storage concerns. This layer lives in the application +core, listens to those generic events, and persists only the `conversation.*` +scope updates that matter to chat applications. 
+""" + import logging -from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID -from dify_graph.conversation_variable_updater import ConversationVariableUpdater -from dify_graph.enums import BuiltinNodeTypes -from dify_graph.graph_engine.layers.base import GraphEngineLayer -from dify_graph.graph_events import GraphEngineEvent, NodeRunSucceededEvent -from dify_graph.nodes.variable_assigner.common import helpers as common_helpers -from dify_graph.variables import VariableBase +from graphon.graph_engine.layers import GraphEngineLayer +from graphon.graph_events import GraphEngineEvent, NodeRunVariableUpdatedEvent + +from core.workflow.system_variables import SystemVariableKey, get_system_text +from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID +from services.conversation_variable_updater import ConversationVariableUpdater logger = logging.getLogger(__name__) @@ -20,41 +28,22 @@ class ConversationVariablePersistenceLayer(GraphEngineLayer): pass def on_event(self, event: GraphEngineEvent) -> None: - if not isinstance(event, NodeRunSucceededEvent): - return - if event.node_type != BuiltinNodeTypes.VARIABLE_ASSIGNER: - return - if self.graph_runtime_state is None: + if not isinstance(event, NodeRunVariableUpdatedEvent): return - updated_variables = common_helpers.get_updated_variables(event.node_run_result.process_data) or [] - if not updated_variables: + selector = event.variable.selector + if len(selector) < 2: + logger.warning("Conversation variable selector invalid. selector=%s", selector) return - conversation_id = self.graph_runtime_state.system_variable.conversation_id + conversation_id = get_system_text(self.graph_runtime_state.variable_pool, SystemVariableKey.CONVERSATION_ID) if conversation_id is None: return - updated_any = False - for item in updated_variables: - selector = item.selector - if len(selector) < 2: - logger.warning("Conversation variable selector invalid. 
selector=%s", selector) - continue - if selector[0] != CONVERSATION_VARIABLE_NODE_ID: - continue - variable = self.graph_runtime_state.variable_pool.get(selector) - if not isinstance(variable, VariableBase): - logger.warning( - "Conversation variable not found in variable pool. selector=%s", - selector, - ) - continue - self._conversation_variable_updater.update(conversation_id=conversation_id, variable=variable) - updated_any = True + if selector[0] != CONVERSATION_VARIABLE_NODE_ID: + return - if updated_any: - self._conversation_variable_updater.flush() + self._conversation_variable_updater.update(conversation_id=conversation_id, variable=event.variable) def on_graph_end(self, error: Exception | None) -> None: pass diff --git a/api/core/app/layers/pause_state_persist_layer.py b/api/core/app/layers/pause_state_persist_layer.py index 4370c01a0b..79a5442130 100644 --- a/api/core/app/layers/pause_state_persist_layer.py +++ b/api/core/app/layers/pause_state_persist_layer.py @@ -1,14 +1,14 @@ from dataclasses import dataclass from typing import Annotated, Literal, Self, TypeAlias +from graphon.graph_engine.layers import GraphEngineLayer +from graphon.graph_events import GraphEngineEvent, GraphRunPausedEvent from pydantic import BaseModel, Field from sqlalchemy import Engine from sqlalchemy.orm import Session, sessionmaker from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity -from dify_graph.graph_engine.layers.base import GraphEngineLayer -from dify_graph.graph_events.base import GraphEngineEvent -from dify_graph.graph_events.graph import GraphRunPausedEvent +from core.workflow.system_variables import SystemVariableKey, get_system_text from models.model import AppMode from repositories.api_workflow_run_repository import APIWorkflowRunRepository from repositories.factory import DifyAPIRepositoryFactory @@ -119,7 +119,10 @@ class PauseStatePersistenceLayer(GraphEngineLayer): generate_entity=entity_wrapper, ) - 
workflow_run_id: str | None = self.graph_runtime_state.system_variable.workflow_execution_id + workflow_run_id = get_system_text( + self.graph_runtime_state.variable_pool, + SystemVariableKey.WORKFLOW_EXECUTION_ID, + ) assert workflow_run_id is not None repo = self._get_repo() repo.create_workflow_pause( diff --git a/api/core/app/layers/suspend_layer.py b/api/core/app/layers/suspend_layer.py index a881fba877..1a79a9f843 100644 --- a/api/core/app/layers/suspend_layer.py +++ b/api/core/app/layers/suspend_layer.py @@ -1,6 +1,5 @@ -from dify_graph.graph_engine.layers.base import GraphEngineLayer -from dify_graph.graph_events.base import GraphEngineEvent -from dify_graph.graph_events.graph import GraphRunPausedEvent +from graphon.graph_engine.layers import GraphEngineLayer +from graphon.graph_events import GraphEngineEvent, GraphRunPausedEvent class SuspendLayer(GraphEngineLayer): diff --git a/api/core/app/layers/timeslice_layer.py b/api/core/app/layers/timeslice_layer.py index d7ca45f209..8c8daf8712 100644 --- a/api/core/app/layers/timeslice_layer.py +++ b/api/core/app/layers/timeslice_layer.py @@ -3,10 +3,10 @@ import uuid from typing import ClassVar from apscheduler.schedulers.background import BackgroundScheduler # type: ignore +from graphon.graph_engine.entities.commands import CommandType, GraphEngineCommand +from graphon.graph_engine.layers import GraphEngineLayer +from graphon.graph_events import GraphEngineEvent -from dify_graph.graph_engine.entities.commands import CommandType, GraphEngineCommand -from dify_graph.graph_engine.layers.base import GraphEngineLayer -from dify_graph.graph_events.base import GraphEngineEvent from services.workflow.entities import WorkflowScheduleCFSPlanEntity from services.workflow.scheduler import CFSPlanScheduler, SchedulerCommand diff --git a/api/core/app/layers/trigger_post_layer.py b/api/core/app/layers/trigger_post_layer.py index a4019a83e1..77c7bec67e 100644 --- a/api/core/app/layers/trigger_post_layer.py +++ 
b/api/core/app/layers/trigger_post_layer.py @@ -2,12 +2,12 @@ import logging from datetime import UTC, datetime from typing import Any, ClassVar +from graphon.graph_engine.layers import GraphEngineLayer +from graphon.graph_events import GraphEngineEvent, GraphRunFailedEvent, GraphRunPausedEvent, GraphRunSucceededEvent from pydantic import TypeAdapter from core.db.session_factory import session_factory -from dify_graph.graph_engine.layers.base import GraphEngineLayer -from dify_graph.graph_events.base import GraphEngineEvent -from dify_graph.graph_events.graph import GraphRunFailedEvent, GraphRunPausedEvent, GraphRunSucceededEvent +from core.workflow.system_variables import SystemVariableKey, get_system_text from models.enums import WorkflowTriggerStatus from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository from tasks.workflow_cfs_scheduler.cfs_scheduler import AsyncWorkflowCFSPlanEntity @@ -59,7 +59,10 @@ class TriggerPostLayer(GraphEngineLayer): outputs = self.graph_runtime_state.outputs # BASICLY, workflow_execution_id is the same as workflow_run_id - workflow_run_id = self.graph_runtime_state.system_variable.workflow_execution_id + workflow_run_id = get_system_text( + self.graph_runtime_state.variable_pool, + SystemVariableKey.WORKFLOW_EXECUTION_ID, + ) assert workflow_run_id, "Workflow run id is not set" total_tokens = self.graph_runtime_state.total_tokens diff --git a/api/core/app/llm/model_access.py b/api/core/app/llm/model_access.py index a63ff39fa5..278d0cb30b 100644 --- a/api/core/app/llm/model_access.py +++ b/api/core/app/llm/model_access.py @@ -2,23 +2,35 @@ from __future__ import annotations from typing import Any -from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.nodes.llm.entities import ModelConfig +from graphon.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError +from 
graphon.nodes.llm.protocols import CredentialsProvider + +from core.app.entities.app_invoke_entities import DifyRunContext, ModelConfigWithCredentialsEntity from core.errors.error import ProviderTokenNotInitError from core.model_manager import ModelInstance, ModelManager +from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager from core.provider_manager import ProviderManager -from dify_graph.model_runtime.entities.model_entities import ModelType -from dify_graph.nodes.llm.entities import ModelConfig -from dify_graph.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError -from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory class DifyCredentialsProvider: tenant_id: str provider_manager: ProviderManager - def __init__(self, tenant_id: str, provider_manager: ProviderManager | None = None) -> None: - self.tenant_id = tenant_id - self.provider_manager = provider_manager or ProviderManager() + def __init__( + self, + *, + run_context: DifyRunContext, + provider_manager: ProviderManager | None = None, + ) -> None: + self.tenant_id = run_context.tenant_id + if provider_manager is None: + provider_manager = create_plugin_provider_manager( + tenant_id=run_context.tenant_id, + user_id=run_context.user_id, + ) + self.provider_manager = provider_manager def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]: provider_configurations = self.provider_manager.get_configurations(self.tenant_id) @@ -42,9 +54,21 @@ class DifyModelFactory: tenant_id: str model_manager: ModelManager - def __init__(self, tenant_id: str, model_manager: ModelManager | None = None) -> None: - self.tenant_id = tenant_id - self.model_manager = model_manager or ModelManager() + def __init__( + self, + *, + run_context: DifyRunContext, + model_manager: ModelManager | None = None, + ) -> None: + self.tenant_id = run_context.tenant_id + if model_manager is None: + model_manager = ModelManager( + 
provider_manager=create_plugin_provider_manager( + tenant_id=run_context.tenant_id, + user_id=run_context.user_id, + ) + ) + self.model_manager = model_manager def init_model_instance(self, provider_name: str, model_name: str) -> ModelInstance: return self.model_manager.get_model_instance( @@ -55,18 +79,42 @@ class DifyModelFactory: ) -def build_dify_model_access(tenant_id: str) -> tuple[CredentialsProvider, ModelFactory]: - return ( - DifyCredentialsProvider(tenant_id=tenant_id), - DifyModelFactory(tenant_id=tenant_id), +def build_dify_model_access(run_context: DifyRunContext) -> tuple[CredentialsProvider, DifyModelFactory]: + """Create LLM access adapters that share the same tenant-bound manager graph.""" + provider_manager = create_plugin_provider_manager( + tenant_id=run_context.tenant_id, + user_id=run_context.user_id, ) + model_manager = ModelManager(provider_manager=provider_manager) + + return ( + DifyCredentialsProvider(run_context=run_context, provider_manager=provider_manager), + DifyModelFactory(run_context=run_context, model_manager=model_manager), + ) + + +def _normalize_completion_params(completion_params: dict[str, Any]) -> tuple[dict[str, Any], list[str]]: + """ + Split node-level completion params into provider parameters and stop sequences. + + Workflow LLM-compatible nodes still consume runtime invocation settings from + ``ModelInstance.parameters`` and ``ModelInstance.stop``. Keep the + ``ModelInstance`` view and the returned config entity aligned here so callers + do not need to duplicate normalization logic. 
+ """ + normalized_parameters = dict(completion_params) + stop = normalized_parameters.pop("stop", []) + if not isinstance(stop, list) or not all(isinstance(item, str) for item in stop): + stop = [] + + return normalized_parameters, stop def fetch_model_config( *, node_data_model: ModelConfig, credentials_provider: CredentialsProvider, - model_factory: ModelFactory, + model_factory: DifyModelFactory, ) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: if not node_data_model.mode: raise LLMModeRequiredError("LLM mode is required.") @@ -80,22 +128,18 @@ def fetch_model_config( model_type=ModelType.LLM, ) if provider_model is None: - raise ModelNotExistError(f"Model {node_data_model.name} not exist.") + raise ModelNotExistError(f"Model {node_data_model.name} does not exist.") provider_model.raise_for_status() - completion_params = dict(node_data_model.completion_params) - stop = completion_params.pop("stop", []) - if not isinstance(stop, list): - stop = [] - model_schema = model_instance.model_type_instance.get_model_schema(node_data_model.name, credentials) - if not model_schema: - raise ModelNotExistError(f"Model {node_data_model.name} not exist.") + if model_schema is None: + raise ModelNotExistError(f"Model {node_data_model.name} schema does not exist.") + parameters, stop = _normalize_completion_params(node_data_model.completion_params) model_instance.provider = node_data_model.provider model_instance.model_name = node_data_model.name model_instance.credentials = credentials - model_instance.parameters = completion_params + model_instance.parameters = parameters model_instance.stop = tuple(stop) return model_instance, ModelConfigWithCredentialsEntity( @@ -103,8 +147,8 @@ def fetch_model_config( model=node_data_model.name, model_schema=model_schema, mode=node_data_model.mode, - provider_model_bundle=provider_model_bundle, credentials=credentials, - parameters=completion_params, + parameters=parameters, stop=stop, + 
provider_model_bundle=provider_model_bundle, ) diff --git a/api/core/app/llm/quota.py b/api/core/app/llm/quota.py index 7aa3bf15ab..63d2235358 100644 --- a/api/core/app/llm/quota.py +++ b/api/core/app/llm/quota.py @@ -1,3 +1,4 @@ +from graphon.model_runtime.entities.llm_entities import LLMUsage from sqlalchemy import update from sqlalchemy.orm import Session @@ -6,7 +7,6 @@ from core.entities.model_entities import ModelStatus from core.entities.provider_entities import ProviderQuotaType, QuotaUnit from core.errors.error import QuotaExceededError from core.model_manager import ModelInstance -from dify_graph.model_runtime.entities.llm_entities import LLMUsage from extensions.ext_database import db from libs.datetime_utils import naive_utc_now from models.provider import Provider, ProviderType diff --git a/api/core/app/task_pipeline/based_generate_task_pipeline.py b/api/core/app/task_pipeline/based_generate_task_pipeline.py index 0d5e0acec6..10b9c36d3e 100644 --- a/api/core/app/task_pipeline/based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/based_generate_task_pipeline.py @@ -1,6 +1,7 @@ import logging import time +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from sqlalchemy import select from sqlalchemy.orm import Session @@ -17,7 +18,6 @@ from core.app.entities.task_entities import ( ) from core.errors.error import QuotaExceededError from core.moderation.output_moderation import ModerationRule, OutputModeration -from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from models.enums import MessageStatus from models.model import Message diff --git a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py index b530fe1ce4..a410fac558 100644 --- a/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py +++ b/api/core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py @@ -4,6 +4,13 @@ 
from collections.abc import Generator from threading import Thread from typing import Any, Union, cast +from graphon.file import FileTransferMethod +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage +from graphon.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + TextPromptMessageContent, +) +from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from sqlalchemy import select from sqlalchemy.orm import Session @@ -51,13 +58,6 @@ from core.ops.entities.trace_entity import TraceTaskName from core.ops.ops_trace_manager import TraceQueueManager, TraceTask from core.prompt.utils.prompt_message_util import PromptMessageUtil from core.prompt.utils.prompt_template_parser import PromptTemplateParser -from dify_graph.file.enums import FileTransferMethod -from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage -from dify_graph.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - TextPromptMessageContent, -) -from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from events.message_event import message_was_created from extensions.ext_database import db from libs.datetime_utils import naive_utc_now diff --git a/api/core/app/task_pipeline/message_file_utils.py b/api/core/app/task_pipeline/message_file_utils.py index fc8b6c6b5a..b23a33923b 100644 --- a/api/core/app/task_pipeline/message_file_utils.py +++ b/api/core/app/task_pipeline/message_file_utils.py @@ -1,8 +1,9 @@ from typing import TypedDict +from graphon.file import FileTransferMethod +from graphon.file import helpers as file_helpers + from core.tools.signature import sign_tool_file -from dify_graph.file import helpers as file_helpers -from dify_graph.file.enums import FileTransferMethod from models.model import MessageFile, UploadFile MAX_TOOL_FILE_EXTENSION_LENGTH 
= 10 diff --git a/api/core/app/workflow/file_runtime.py b/api/core/app/workflow/file_runtime.py index e0f8d27111..8604235ef2 100644 --- a/api/core/app/workflow/file_runtime.py +++ b/api/core/app/workflow/file_runtime.py @@ -1,33 +1,43 @@ from __future__ import annotations +import base64 +import hashlib +import hmac +import os +import time +import urllib.parse from collections.abc import Generator +from typing import TYPE_CHECKING, Literal + +from graphon.file import FileTransferMethod +from graphon.file.protocols import HttpResponseProtocol, WorkflowFileRuntimeProtocol +from graphon.file.runtime import set_workflow_file_runtime from configs import dify_config +from core.app.file_access import DatabaseFileAccessController, FileAccessControllerProtocol +from core.db.session_factory import session_factory from core.helper.ssrf_proxy import ssrf_proxy from core.tools.signature import sign_tool_file -from dify_graph.file.protocols import HttpResponseProtocol, WorkflowFileRuntimeProtocol -from dify_graph.file.runtime import set_workflow_file_runtime +from core.workflow.file_reference import parse_file_reference from extensions.ext_storage import storage +if TYPE_CHECKING: + from graphon.file import File + class DifyWorkflowFileRuntime(WorkflowFileRuntimeProtocol): - """Production runtime wiring for ``dify_graph.file``.""" + """Production runtime wiring for ``graphon.file``. - @property - def files_url(self) -> str: - return dify_config.FILES_URL + Opaque file references are resolved back to canonical database records before + URLs are signed or storage keys are used. When a request-scoped file access + scope is present, those lookups additionally enforce tenant and end-user + ownership filters. 
+ """ - @property - def internal_files_url(self) -> str | None: - return dify_config.INTERNAL_FILES_URL + _file_access_controller: FileAccessControllerProtocol - @property - def secret_key(self) -> str: - return dify_config.SECRET_KEY - - @property - def files_access_timeout(self) -> int: - return dify_config.FILES_ACCESS_TIMEOUT + def __init__(self, *, file_access_controller: FileAccessControllerProtocol) -> None: + self._file_access_controller = file_access_controller @property def multimodal_send_format(self) -> str: @@ -39,9 +49,137 @@ class DifyWorkflowFileRuntime(WorkflowFileRuntimeProtocol): def storage_load(self, path: str, *, stream: bool = False) -> bytes | Generator: return storage.load(path, stream=stream) - def sign_tool_file(self, *, tool_file_id: str, extension: str, for_external: bool = True) -> str: + def load_file_bytes(self, *, file: File) -> bytes: + storage_key = self._resolve_storage_key(file=file) + data = storage.load(storage_key, stream=False) + if not isinstance(data, bytes): + raise ValueError(f"file {storage_key} is not a bytes object") + return data + + def resolve_file_url(self, *, file: File, for_external: bool = True) -> str | None: + if file.transfer_method == FileTransferMethod.REMOTE_URL: + return file.remote_url + parsed_reference = parse_file_reference(file.reference) + if parsed_reference is None: + raise ValueError("Missing file reference") + if file.transfer_method == FileTransferMethod.LOCAL_FILE: + return self.resolve_upload_file_url( + upload_file_id=parsed_reference.record_id, + for_external=for_external, + ) + if file.transfer_method == FileTransferMethod.DATASOURCE_FILE: + if file.extension is None: + raise ValueError("Missing file extension") + self._assert_upload_file_access(upload_file_id=parsed_reference.record_id) + return sign_tool_file( + tool_file_id=parsed_reference.record_id, + extension=file.extension, + for_external=for_external, + ) + if file.transfer_method == FileTransferMethod.TOOL_FILE: + if 
file.extension is None: + raise ValueError("Missing file extension") + return self.resolve_tool_file_url( + tool_file_id=parsed_reference.record_id, + extension=file.extension, + for_external=for_external, + ) + return None + + def resolve_upload_file_url( + self, + *, + upload_file_id: str, + as_attachment: bool = False, + for_external: bool = True, + ) -> str: + self._assert_upload_file_access(upload_file_id=upload_file_id) + base_url = self._base_url(for_external=for_external) + url = f"{base_url}/files/{upload_file_id}/file-preview" + query = self._sign_query(payload=f"file-preview|{upload_file_id}") + if as_attachment: + query["as_attachment"] = "true" + return f"{url}?{urllib.parse.urlencode(query)}" + + def resolve_tool_file_url(self, *, tool_file_id: str, extension: str, for_external: bool = True) -> str: + self._assert_tool_file_access(tool_file_id=tool_file_id) return sign_tool_file(tool_file_id=tool_file_id, extension=extension, for_external=for_external) + def verify_preview_signature( + self, + *, + preview_kind: Literal["image", "file"], + file_id: str, + timestamp: str, + nonce: str, + sign: str, + ) -> bool: + payload = f"{preview_kind}-preview|{file_id}|{timestamp}|{nonce}" + recalculated = hmac.new(self._secret_key(), payload.encode(), hashlib.sha256).digest() + if sign != base64.urlsafe_b64encode(recalculated).decode(): + return False + return int(time.time()) - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT + + @staticmethod + def _base_url(*, for_external: bool) -> str: + if for_external: + return dify_config.FILES_URL + return dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL + + @staticmethod + def _secret_key() -> bytes: + return dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b"" + + def _sign_query(self, *, payload: str) -> dict[str, str]: + timestamp = str(int(time.time())) + nonce = os.urandom(16).hex() + sign = hmac.new(self._secret_key(), f"{payload}|{timestamp}|{nonce}".encode(), hashlib.sha256).digest() + 
return { + "timestamp": timestamp, + "nonce": nonce, + "sign": base64.urlsafe_b64encode(sign).decode(), + } + + def _resolve_storage_key(self, *, file: File) -> str: + parsed_reference = parse_file_reference(file.reference) + if parsed_reference is None: + raise ValueError("Missing file reference") + + record_id = parsed_reference.record_id + with session_factory.create_session() as session: + if file.transfer_method in { + FileTransferMethod.LOCAL_FILE, + FileTransferMethod.REMOTE_URL, + FileTransferMethod.DATASOURCE_FILE, + }: + upload_file = self._file_access_controller.get_upload_file(session=session, file_id=record_id) + if upload_file is None: + raise ValueError(f"Upload file {record_id} not found") + return upload_file.key + + tool_file = self._file_access_controller.get_tool_file(session=session, file_id=record_id) + if tool_file is None: + raise ValueError(f"Tool file {record_id} not found") + return tool_file.file_key + + def _assert_upload_file_access(self, *, upload_file_id: str) -> None: + if self._file_access_controller.current_scope() is None: + return + + with session_factory.create_session() as session: + upload_file = self._file_access_controller.get_upload_file(session=session, file_id=upload_file_id) + if upload_file is None: + raise ValueError(f"Upload file {upload_file_id} not found") + + def _assert_tool_file_access(self, *, tool_file_id: str) -> None: + if self._file_access_controller.current_scope() is None: + return + + with session_factory.create_session() as session: + tool_file = self._file_access_controller.get_tool_file(session=session, file_id=tool_file_id) + if tool_file is None: + raise ValueError(f"Tool file {tool_file_id} not found") + def bind_dify_workflow_file_runtime() -> None: - set_workflow_file_runtime(DifyWorkflowFileRuntime()) + set_workflow_file_runtime(DifyWorkflowFileRuntime(file_access_controller=DatabaseFileAccessController())) diff --git a/api/core/app/workflow/layers/llm_quota.py 
b/api/core/app/workflow/layers/llm_quota.py index faf1516c40..48cabaf4d0 100644 --- a/api/core/app/workflow/layers/llm_quota.py +++ b/api/core/app/workflow/layers/llm_quota.py @@ -7,22 +7,22 @@ This layer centralizes model-quota deduction outside node implementations. import logging from typing import TYPE_CHECKING, cast, final +from graphon.enums import BuiltinNodeTypes +from graphon.graph_engine.entities.commands import AbortCommand, CommandType +from graphon.graph_engine.layers import GraphEngineLayer +from graphon.graph_events import GraphEngineEvent, GraphNodeEventBase, NodeRunSucceededEvent +from graphon.nodes.base.node import Node from typing_extensions import override +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext from core.app.llm import deduct_llm_quota, ensure_llm_quota_available from core.errors.error import QuotaExceededError from core.model_manager import ModelInstance -from dify_graph.enums import BuiltinNodeTypes -from dify_graph.graph_engine.entities.commands import AbortCommand, CommandType -from dify_graph.graph_engine.layers.base import GraphEngineLayer -from dify_graph.graph_events import GraphEngineEvent, GraphNodeEventBase -from dify_graph.graph_events.node import NodeRunSucceededEvent -from dify_graph.nodes.base.node import Node if TYPE_CHECKING: - from dify_graph.nodes.llm.node import LLMNode - from dify_graph.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode - from dify_graph.nodes.question_classifier.question_classifier_node import QuestionClassifierNode + from graphon.nodes.llm.node import LLMNode + from graphon.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode + from graphon.nodes.question_classifier.question_classifier_node import QuestionClassifierNode logger = logging.getLogger(__name__) @@ -75,7 +75,7 @@ class LLMQuotaLayer(GraphEngineLayer): return try: - dify_ctx = node.require_dify_context() + dify_ctx = 
DifyRunContext.model_validate(node.require_run_context_value(DIFY_RUN_CONTEXT_KEY)) deduct_llm_quota( tenant_id=dify_ctx.tenant_id, model_instance=model_instance, @@ -114,11 +114,11 @@ class LLMQuotaLayer(GraphEngineLayer): try: match node.node_type: case BuiltinNodeTypes.LLM: - return cast("LLMNode", node).model_instance + model_instance = cast("LLMNode", node).model_instance case BuiltinNodeTypes.PARAMETER_EXTRACTOR: - return cast("ParameterExtractorNode", node).model_instance + model_instance = cast("ParameterExtractorNode", node).model_instance case BuiltinNodeTypes.QUESTION_CLASSIFIER: - return cast("QuestionClassifierNode", node).model_instance + model_instance = cast("QuestionClassifierNode", node).model_instance case _: return None except AttributeError: @@ -127,3 +127,12 @@ class LLMQuotaLayer(GraphEngineLayer): node.id, ) return None + + if isinstance(model_instance, ModelInstance): + return model_instance + + raw_model_instance = getattr(model_instance, "_model_instance", None) + if isinstance(raw_model_instance, ModelInstance): + return raw_model_instance + + return None diff --git a/api/core/app/workflow/layers/observability.py b/api/core/app/workflow/layers/observability.py index 4b20477a7f..8565c3076c 100644 --- a/api/core/app/workflow/layers/observability.py +++ b/api/core/app/workflow/layers/observability.py @@ -11,15 +11,15 @@ import logging from dataclasses import dataclass from typing import cast, final +from graphon.enums import BuiltinNodeTypes, NodeType +from graphon.graph_engine.layers import GraphEngineLayer +from graphon.graph_events import GraphNodeEventBase +from graphon.nodes.base.node import Node from opentelemetry import context as context_api from opentelemetry.trace import Span, SpanKind, Tracer, get_tracer, set_span_in_context from typing_extensions import override from configs import dify_config -from dify_graph.enums import BuiltinNodeTypes, NodeType -from dify_graph.graph_engine.layers.base import GraphEngineLayer -from 
dify_graph.graph_events import GraphNodeEventBase -from dify_graph.nodes.base.node import Node from extensions.otel.parser import ( DefaultNodeOTelParser, LLMNodeOTelParser, diff --git a/api/core/app/workflow/layers/persistence.py b/api/core/app/workflow/layers/persistence.py index d95a378575..ada065a943 100644 --- a/api/core/app/workflow/layers/persistence.py +++ b/api/core/app/workflow/layers/persistence.py @@ -14,20 +14,15 @@ from dataclasses import dataclass from datetime import datetime from typing import Any, Union -from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity -from core.ops.entities.trace_entity import TraceTaskName -from core.ops.ops_trace_manager import TraceQueueManager, TraceTask -from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID -from dify_graph.entities import WorkflowExecution, WorkflowNodeExecution -from dify_graph.enums import ( - SystemVariableKey, +from graphon.entities import WorkflowExecution, WorkflowNodeExecution +from graphon.enums import ( WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus, WorkflowType, ) -from dify_graph.graph_engine.layers.base import GraphEngineLayer -from dify_graph.graph_events import ( +from graphon.graph_engine.layers import GraphEngineLayer +from graphon.graph_events import ( GraphEngineEvent, GraphRunAbortedEvent, GraphRunFailedEvent, @@ -42,9 +37,15 @@ from dify_graph.graph_events import ( NodeRunStartedEvent, NodeRunSucceededEvent, ) -from dify_graph.node_events import NodeRunResult -from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository -from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from graphon.node_events import NodeRunResult + +from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity +from core.ops.entities.trace_entity import TraceTaskName +from 
core.ops.ops_trace_manager import TraceQueueManager, TraceTask +from core.repositories.factory import WorkflowExecutionRepository, WorkflowNodeExecutionRepository +from core.workflow.system_variables import SystemVariableKey +from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID +from core.workflow.workflow_run_outputs import project_node_outputs_for_workflow_run from libs.datetime_utils import naive_utc_now @@ -372,10 +373,15 @@ class WorkflowPersistenceLayer(GraphEngineLayer): domain_execution.error = error if update_outputs: + projected_outputs = project_node_outputs_for_workflow_run( + node_type=domain_execution.node_type, + inputs=node_result.inputs, + outputs=node_result.outputs, + ) domain_execution.update_from_mapping( inputs=node_result.inputs, process_data=node_result.process_data, - outputs=node_result.outputs, + outputs=projected_outputs, metadata=node_result.metadata, ) diff --git a/api/core/base/tts/app_generator_tts_publisher.py b/api/core/base/tts/app_generator_tts_publisher.py index beda515666..3d8a7a54f3 100644 --- a/api/core/base/tts/app_generator_tts_publisher.py +++ b/api/core/base/tts/app_generator_tts_publisher.py @@ -6,6 +6,9 @@ import re import threading from collections.abc import Iterable +from graphon.model_runtime.entities.message_entities import TextPromptMessageContent +from graphon.model_runtime.entities.model_entities import ModelType + from core.app.entities.queue_entities import ( MessageQueueMessage, QueueAgentMessageEvent, @@ -15,8 +18,6 @@ from core.app.entities.queue_entities import ( WorkflowQueueMessage, ) from core.model_manager import ModelInstance, ModelManager -from dify_graph.model_runtime.entities.message_entities import TextPromptMessageContent -from dify_graph.model_runtime.entities.model_entities import ModelType class AudioTrunk: @@ -25,12 +26,10 @@ class AudioTrunk: self.status = status -def _invoice_tts(text_content: str, model_instance: ModelInstance, tenant_id: str, voice: str): +def 
_invoice_tts(text_content: str, model_instance: ModelInstance, voice: str): if not text_content or text_content.isspace(): return - return model_instance.invoke_tts( - content_text=text_content.strip(), user="responding_tts", tenant_id=tenant_id, voice=voice - ) + return model_instance.invoke_tts(content_text=text_content.strip(), voice=voice) def _process_future( @@ -62,7 +61,7 @@ class AppGeneratorTTSPublisher: self._audio_queue: queue.Queue[AudioTrunk] = queue.Queue() self._msg_queue: queue.Queue[WorkflowQueueMessage | MessageQueueMessage | None] = queue.Queue() self.match = re.compile(r"[。.!?]") - self.model_manager = ModelManager() + self.model_manager = ModelManager.for_tenant(tenant_id=self.tenant_id, user_id="responding_tts") self.model_instance = self.model_manager.get_default_model_instance( tenant_id=self.tenant_id, model_type=ModelType.TTS ) @@ -89,7 +88,7 @@ class AppGeneratorTTSPublisher: if message is None: if self.msg_text and len(self.msg_text.strip()) > 0: futures_result = self.executor.submit( - _invoice_tts, self.msg_text, self.model_instance, self.tenant_id, self.voice + _invoice_tts, self.msg_text, self.model_instance, self.voice ) future_queue.put(futures_result) break @@ -117,9 +116,7 @@ class AppGeneratorTTSPublisher: if len(sentence_arr) >= min(self.max_sentence, 7): self.max_sentence += 1 text_content = "".join(sentence_arr) - futures_result = self.executor.submit( - _invoice_tts, text_content, self.model_instance, self.tenant_id, self.voice - ) + futures_result = self.executor.submit(_invoice_tts, text_content, self.model_instance, self.voice) future_queue.put(futures_result) if isinstance(text_tmp, str): self.msg_text = text_tmp diff --git a/api/core/callback_handler/index_tool_callback_handler.py b/api/core/callback_handler/index_tool_callback_handler.py index 8de5cb1690..6a07119244 100644 --- a/api/core/callback_handler/index_tool_callback_handler.py +++ b/api/core/callback_handler/index_tool_callback_handler.py @@ -1,7 +1,7 @@ import 
logging from collections.abc import Sequence -from sqlalchemy import select +from sqlalchemy import select, update from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom from core.app.entities.app_invoke_entities import InvokeFrom @@ -70,23 +70,21 @@ class DatasetIndexToolCallbackHandler: ) child_chunk = db.session.scalar(child_chunk_stmt) if child_chunk: - _ = ( - db.session.query(DocumentSegment) + db.session.execute( + update(DocumentSegment) .where(DocumentSegment.id == child_chunk.segment_id) - .update( - {DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False - ) + .values(hit_count=DocumentSegment.hit_count + 1) ) else: - query = db.session.query(DocumentSegment).where( - DocumentSegment.index_node_id == document.metadata["doc_id"] - ) + conditions = [DocumentSegment.index_node_id == document.metadata["doc_id"]] if "dataset_id" in document.metadata: - query = query.where(DocumentSegment.dataset_id == document.metadata["dataset_id"]) + conditions.append(DocumentSegment.dataset_id == document.metadata["dataset_id"]) # add hit count to document segment - query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False) + db.session.execute( + update(DocumentSegment).where(*conditions).values(hit_count=DocumentSegment.hit_count + 1) + ) db.session.commit() diff --git a/api/core/datasource/datasource_file_manager.py b/api/core/datasource/datasource_file_manager.py index 24243add17..fe40d8f0e5 100644 --- a/api/core/datasource/datasource_file_manager.py +++ b/api/core/datasource/datasource_file_manager.py @@ -214,6 +214,6 @@ class DatasourceFileManager: # init tool_file_parser -# from dify_graph.file.datasource_file_parser import datasource_file_manager +# from graphon.file.datasource_file_parser import datasource_file_manager # # datasource_file_manager["manager"] = DatasourceFileManager diff --git a/api/core/datasource/datasource_manager.py 
b/api/core/datasource/datasource_manager.py index 4fa941ae16..143d1e696b 100644 --- a/api/core/datasource/datasource_manager.py +++ b/api/core/datasource/datasource_manager.py @@ -3,9 +3,13 @@ from collections.abc import Generator from threading import Lock from typing import Any, cast +from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from graphon.file import File, FileTransferMethod, FileType, get_file_type_by_mime_type +from graphon.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent from sqlalchemy import select import contexts +from core.app.file_access import DatabaseFileAccessController from core.datasource.__base.datasource_plugin import DatasourcePlugin from core.datasource.__base.datasource_provider import DatasourcePluginProviderController from core.datasource.entities.datasource_entities import ( @@ -24,18 +28,15 @@ from core.datasource.utils.message_transformer import DatasourceFileMessageTrans from core.datasource.website_crawl.website_crawl_provider import WebsiteCrawlDatasourcePluginProviderController from core.db.session_factory import session_factory from core.plugin.impl.datasource import PluginDatasourceManager +from core.workflow.file_reference import build_file_reference from core.workflow.nodes.datasource.entities import DatasourceParameter, OnlineDriveDownloadFileParam -from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from dify_graph.enums import WorkflowNodeExecutionMetadataKey -from dify_graph.file import File -from dify_graph.file.enums import FileTransferMethod, FileType -from dify_graph.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent from factories import file_factory from models.model import UploadFile from models.tools import ToolFile from services.datasource_provider_service import DatasourceProviderService logger = logging.getLogger(__name__) +_file_access_controller = DatabaseFileAccessController() class 
DatasourceManager: @@ -279,11 +280,15 @@ class DatasourceManager: if datasource_file is not None: mapping = { "tool_file_id": datasource_file_id, - "type": file_factory.get_file_type_by_mime_type(mime_type), + "type": get_file_type_by_mime_type(mime_type), "transfer_method": FileTransferMethod.TOOL_FILE, "url": url, } - file_out = file_factory.build_from_mapping(mapping=mapping, tenant_id=tenant_id) + file_out = file_factory.build_from_mapping( + mapping=mapping, + tenant_id=tenant_id, + access_controller=_file_access_controller, + ) elif mtype == DatasourceMessage.MessageType.TEXT: assert isinstance(message.message, DatasourceMessage.TextMessage) yield StreamChunkEvent(selector=[node_id, "text"], chunk=message.message.text, is_final=False) @@ -351,11 +356,10 @@ class DatasourceManager: filename=upload_file.name, extension="." + upload_file.extension, mime_type=upload_file.mime_type, - tenant_id=tenant_id, type=FileType.CUSTOM, transfer_method=FileTransferMethod.LOCAL_FILE, remote_url=upload_file.source_url, - related_id=upload_file.id, + reference=build_file_reference(record_id=str(upload_file.id)), size=upload_file.size, storage_key=upload_file.key, url=upload_file.source_url, diff --git a/api/core/datasource/entities/api_entities.py b/api/core/datasource/entities/api_entities.py index 4c9ff64479..14d1af2e8b 100644 --- a/api/core/datasource/entities/api_entities.py +++ b/api/core/datasource/entities/api_entities.py @@ -1,10 +1,10 @@ from typing import Literal, Optional +from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel, Field, field_validator from core.datasource.entities.datasource_entities import DatasourceParameter from core.tools.entities.common_entities import I18nObject -from dify_graph.model_runtime.utils.encoders import jsonable_encoder class DatasourceApiEntity(BaseModel): diff --git a/api/core/datasource/utils/message_transformer.py b/api/core/datasource/utils/message_transformer.py index 
2881888e27..04f15dee31 100644 --- a/api/core/datasource/utils/message_transformer.py +++ b/api/core/datasource/utils/message_transformer.py @@ -2,9 +2,11 @@ import logging from collections.abc import Generator from mimetypes import guess_extension, guess_type +from graphon.file import File, FileTransferMethod, FileType + from core.datasource.entities.datasource_entities import DatasourceMessage from core.tools.tool_file_manager import ToolFileManager -from dify_graph.file import File, FileTransferMethod, FileType +from core.workflow.file_reference import parse_file_reference from models.tools import ToolFile logger = logging.getLogger(__name__) @@ -103,8 +105,14 @@ class DatasourceFileMessageTransformer: file: File | None = meta.get("file") if isinstance(file, File): if file.transfer_method == FileTransferMethod.TOOL_FILE: - assert file.related_id is not None - url = cls.get_datasource_file_url(datasource_file_id=file.related_id, extension=file.extension) + reference = getattr(file, "reference", None) or getattr(file, "related_id", None) + parsed_reference = parse_file_reference(reference) if isinstance(reference, str) else None + if parsed_reference is None: + raise ValueError("datasource file is missing reference") + url = cls.get_datasource_file_url( + datasource_file_id=parsed_reference.record_id, + extension=file.extension, + ) if file.type == FileType.IMAGE: yield DatasourceMessage( type=DatasourceMessage.MessageType.IMAGE_LINK, diff --git a/api/core/entities/embedding_type.py b/api/core/entities/embedding_type.py index 89b48fd2ef..f49cbf9ffe 100644 --- a/api/core/entities/embedding_type.py +++ b/api/core/entities/embedding_type.py @@ -1,10 +1,5 @@ -from enum import StrEnum, auto +"""Compatibility wrapper for the runtime embedding input enum.""" +from graphon.model_runtime.entities.text_embedding_entities import EmbeddingInputType -class EmbeddingInputType(StrEnum): - """ - Enum for embedding input type. 
- """ - - DOCUMENT = auto() - QUERY = auto() +__all__ = ["EmbeddingInputType"] diff --git a/api/core/entities/execution_extra_content.py b/api/core/entities/execution_extra_content.py index 1343bd8e82..72f6590e68 100644 --- a/api/core/entities/execution_extra_content.py +++ b/api/core/entities/execution_extra_content.py @@ -3,9 +3,9 @@ from __future__ import annotations from collections.abc import Mapping, Sequence from typing import Any, TypeAlias +from graphon.nodes.human_input.entities import FormInput, UserAction from pydantic import BaseModel, ConfigDict, Field -from dify_graph.nodes.human_input.entities import FormInput, UserAction from models.execution_extra_content import ExecutionContentType diff --git a/api/core/entities/mcp_provider.py b/api/core/entities/mcp_provider.py index d214652e9c..a440829b46 100644 --- a/api/core/entities/mcp_provider.py +++ b/api/core/entities/mcp_provider.py @@ -6,6 +6,7 @@ from enum import StrEnum from typing import TYPE_CHECKING, Any from urllib.parse import urlparse +from graphon.file import helpers as file_helpers from pydantic import BaseModel from configs import dify_config @@ -15,7 +16,6 @@ from core.helper.provider_cache import NoOpProviderCredentialCache from core.mcp.types import OAuthClientInformation, OAuthClientMetadata, OAuthTokens from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolProviderType -from dify_graph.file import helpers as file_helpers if TYPE_CHECKING: from models.tools import MCPToolProvider diff --git a/api/core/entities/model_entities.py b/api/core/entities/model_entities.py index 3427fc54b1..84d95c38c6 100644 --- a/api/core/entities/model_entities.py +++ b/api/core/entities/model_entities.py @@ -1,12 +1,11 @@ from collections.abc import Sequence from enum import StrEnum, auto +from graphon.model_runtime.entities.common_entities import I18nObject +from graphon.model_runtime.entities.model_entities import ModelType, ProviderModel +from 
graphon.model_runtime.entities.provider_entities import ProviderEntity from pydantic import BaseModel, ConfigDict -from dify_graph.model_runtime.entities.common_entities import I18nObject -from dify_graph.model_runtime.entities.model_entities import ModelType, ProviderModel -from dify_graph.model_runtime.entities.provider_entities import ProviderEntity - class ModelStatus(StrEnum): """ diff --git a/api/core/entities/provider_configuration.py b/api/core/entities/provider_configuration.py index a9f2300ba2..8b48aa2660 100644 --- a/api/core/entities/provider_configuration.py +++ b/api/core/entities/provider_configuration.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import json import logging import re @@ -5,7 +7,17 @@ from collections import defaultdict from collections.abc import Iterator, Sequence from json import JSONDecodeError -from pydantic import BaseModel, ConfigDict, Field, model_validator +from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType +from graphon.model_runtime.entities.provider_entities import ( + ConfigurateMethod, + CredentialFormSchema, + FormType, + ProviderEntity, +) +from graphon.model_runtime.model_providers.__base.ai_model import AIModel +from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory +from graphon.model_runtime.runtime import ModelRuntime +from pydantic import BaseModel, ConfigDict, Field, PrivateAttr, model_validator from sqlalchemy import func, select from sqlalchemy.orm import Session @@ -19,15 +31,7 @@ from core.entities.provider_entities import ( ) from core.helper import encrypter from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType -from dify_graph.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType -from dify_graph.model_runtime.entities.provider_entities import ( - ConfigurateMethod, - CredentialFormSchema, - FormType, - ProviderEntity, -) -from 
dify_graph.model_runtime.model_providers.__base.ai_model import AIModel -from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory +from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory from libs.datetime_utils import naive_utc_now from models.engine import db from models.enums import CredentialSourceType @@ -60,6 +64,10 @@ class ProviderConfiguration(BaseModel): - Load balancing configurations - Model enablement/disablement + Request flows can bind a pre-scoped runtime via ``bind_model_runtime()`` so + nested schema and model lookups reuse the caller scope that was already + resolved by the composition layer. + TODO: lots of logic in a BaseModel entity should be separated, the exceptions should be classified """ @@ -73,6 +81,7 @@ class ProviderConfiguration(BaseModel): # pydantic configs model_config = ConfigDict(protected_namespaces=()) + _bound_model_runtime: ModelRuntime | None = PrivateAttr(default=None) @model_validator(mode="after") def _(self): @@ -92,6 +101,16 @@ class ProviderConfiguration(BaseModel): self.provider.configurate_methods.append(ConfigurateMethod.PREDEFINED_MODEL) return self + def bind_model_runtime(self, model_runtime: ModelRuntime) -> None: + """Attach the already-composed runtime for request-bound call chains.""" + self._bound_model_runtime = model_runtime + + def get_model_provider_factory(self) -> ModelProviderFactory: + """Return a provider factory that preserves any request-bound runtime.""" + if self._bound_model_runtime is not None: + return ModelProviderFactory(model_runtime=self._bound_model_runtime) + return create_plugin_model_provider_factory(tenant_id=self.tenant_id) + def get_current_credentials(self, model_type: ModelType, model: str) -> dict | None: """ Get current credentials. 
@@ -343,7 +362,7 @@ class ProviderConfiguration(BaseModel): tenant_id=self.tenant_id, token=original_credentials[key] ) - model_provider_factory = ModelProviderFactory(self.tenant_id) + model_provider_factory = self.get_model_provider_factory() validated_credentials = model_provider_factory.provider_credentials_validate( provider=self.provider.provider, credentials=credentials ) @@ -902,7 +921,7 @@ class ProviderConfiguration(BaseModel): tenant_id=self.tenant_id, token=original_credentials[key] ) - model_provider_factory = ModelProviderFactory(self.tenant_id) + model_provider_factory = self.get_model_provider_factory() validated_credentials = model_provider_factory.model_credentials_validate( provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials ) @@ -1388,7 +1407,7 @@ class ProviderConfiguration(BaseModel): :param model_type: model type :return: """ - model_provider_factory = ModelProviderFactory(self.tenant_id) + model_provider_factory = self.get_model_provider_factory() # Get model instance of LLM return model_provider_factory.get_model_type_instance(provider=self.provider.provider, model_type=model_type) @@ -1397,7 +1416,7 @@ class ProviderConfiguration(BaseModel): """ Get model schema """ - model_provider_factory = ModelProviderFactory(self.tenant_id) + model_provider_factory = self.get_model_provider_factory() return model_provider_factory.get_model_schema( provider=self.provider.provider, model_type=model_type, model=model, credentials=credentials ) @@ -1499,7 +1518,7 @@ class ProviderConfiguration(BaseModel): :param model: model name :return: """ - model_provider_factory = ModelProviderFactory(self.tenant_id) + model_provider_factory = self.get_model_provider_factory() provider_schema = model_provider_factory.get_provider_schema(self.provider.provider) model_types: list[ModelType] = [] diff --git a/api/core/entities/provider_entities.py b/api/core/entities/provider_entities.py index a830f227a9..2c8767a32b 100644 --- 
a/api/core/entities/provider_entities.py +++ b/api/core/entities/provider_entities.py @@ -3,6 +3,7 @@ from __future__ import annotations from enum import StrEnum, auto from typing import Union +from graphon.model_runtime.entities.model_entities import ModelType from pydantic import BaseModel, ConfigDict, Field from core.entities.parameter_entities import ( @@ -12,7 +13,6 @@ from core.entities.parameter_entities import ( ToolSelectorScope, ) from core.tools.entities.common_entities import I18nObject -from dify_graph.model_runtime.entities.model_entities import ModelType class ProviderQuotaType(StrEnum): diff --git a/api/core/helper/code_executor/code_executor.py b/api/core/helper/code_executor/code_executor.py index 4251cfd30b..35bfcfb6a5 100644 --- a/api/core/helper/code_executor/code_executor.py +++ b/api/core/helper/code_executor/code_executor.py @@ -4,6 +4,7 @@ from threading import Lock from typing import Any import httpx +from graphon.nodes.code.entities import CodeLanguage from pydantic import BaseModel from yarl import URL @@ -13,7 +14,6 @@ from core.helper.code_executor.jinja2.jinja2_transformer import Jinja2TemplateTr from core.helper.code_executor.python3.python3_transformer import Python3TemplateTransformer from core.helper.code_executor.template_transformer import TemplateTransformer from core.helper.http_client_pooling import get_pooled_http_client -from dify_graph.nodes.code.entities import CodeLanguage logger = logging.getLogger(__name__) code_execution_endpoint_url = URL(str(dify_config.CODE_EXECUTION_ENDPOINT)) diff --git a/api/core/helper/code_executor/template_transformer.py b/api/core/helper/code_executor/template_transformer.py index c569e066f4..b96a9ce380 100644 --- a/api/core/helper/code_executor/template_transformer.py +++ b/api/core/helper/code_executor/template_transformer.py @@ -5,7 +5,7 @@ from base64 import b64encode from collections.abc import Mapping from typing import Any -from dify_graph.variables.utils import dumps_with_segments 
+from graphon.variables.utils import dumps_with_segments class TemplateTransformer(ABC): diff --git a/api/core/helper/encrypter.py b/api/core/helper/encrypter.py index 17345dc203..20125ec6b3 100644 --- a/api/core/helper/encrypter.py +++ b/api/core/helper/encrypter.py @@ -19,7 +19,7 @@ def encrypt_token(tenant_id: str, token: str): from extensions.ext_database import db from models.account import Tenant - if not (tenant := db.session.query(Tenant).where(Tenant.id == tenant_id).first()): + if not (tenant := db.session.get(Tenant, tenant_id)): raise ValueError(f"Tenant with id {tenant_id} not found") assert tenant.encrypt_public_key is not None encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key) diff --git a/api/core/helper/moderation.py b/api/core/helper/moderation.py index 873f6a4093..a1e782a094 100644 --- a/api/core/helper/moderation.py +++ b/api/core/helper/moderation.py @@ -2,12 +2,13 @@ import logging import secrets from typing import cast +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.errors.invoke import InvokeBadRequestError +from graphon.model_runtime.model_providers.__base.moderation_model import ModerationModel + from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.entities import DEFAULT_PLUGIN_ID -from dify_graph.model_runtime.entities.model_entities import ModelType -from dify_graph.model_runtime.errors.invoke import InvokeBadRequestError -from dify_graph.model_runtime.model_providers.__base.moderation_model import ModerationModel -from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory +from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory from extensions.ext_hosting_provider import hosting_configuration from models.provider import ProviderType @@ -41,7 +42,7 @@ def check_moderation(tenant_id: str, model_config: ModelConfigWithCredentialsEnt text_chunk = 
secrets.choice(text_chunks) try: - model_provider_factory = ModelProviderFactory(tenant_id) + model_provider_factory = create_plugin_model_provider_factory(tenant_id=tenant_id) # Get model instance of LLM model_type_instance = model_provider_factory.get_model_type_instance( diff --git a/api/core/hosting_configuration.py b/api/core/hosting_configuration.py index 600a444357..60f5434bc1 100644 --- a/api/core/hosting_configuration.py +++ b/api/core/hosting_configuration.py @@ -1,10 +1,10 @@ from flask import Flask +from graphon.model_runtime.entities.model_entities import ModelType from pydantic import BaseModel from configs import dify_config from core.entities import DEFAULT_PLUGIN_ID from core.entities.provider_entities import ProviderQuotaType, QuotaUnit, RestrictModel -from dify_graph.model_runtime.entities.model_entities import ModelType class HostingQuota(BaseModel): diff --git a/api/core/indexing_runner.py b/api/core/indexing_runner.py index 06bc366081..3ec17bc986 100644 --- a/api/core/indexing_runner.py +++ b/api/core/indexing_runner.py @@ -9,6 +9,7 @@ from collections.abc import Mapping from typing import Any from flask import Flask, current_app +from graphon.model_runtime.entities.model_entities import ModelType from sqlalchemy import select from sqlalchemy.orm.exc import ObjectDeletedError @@ -31,7 +32,6 @@ from core.rag.splitter.fixed_text_splitter import ( ) from core.rag.splitter.text_splitter import TextSplitter from core.tools.utils.web_reader_tool import get_image_upload_file_ids -from dify_graph.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db from extensions.ext_redis import redis_client from extensions.ext_storage import storage @@ -50,7 +50,10 @@ logger = logging.getLogger(__name__) class IndexingRunner: def __init__(self): self.storage = storage - self.model_manager = ModelManager() + + @staticmethod + def _get_model_manager(tenant_id: str) -> ModelManager: + return 
ModelManager.for_tenant(tenant_id=tenant_id) def _handle_indexing_error(self, document_id: str, error: Exception) -> None: """Handle indexing errors by updating document status.""" @@ -291,20 +294,20 @@ class IndexingRunner: raise ValueError("Dataset not found.") if IndexTechniqueType.HIGH_QUALITY in {dataset.indexing_technique, indexing_technique}: if dataset.embedding_model_provider: - embedding_model_instance = self.model_manager.get_model_instance( + embedding_model_instance = self._get_model_manager(tenant_id).get_model_instance( tenant_id=tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, model=dataset.embedding_model, ) else: - embedding_model_instance = self.model_manager.get_default_model_instance( + embedding_model_instance = self._get_model_manager(tenant_id).get_default_model_instance( tenant_id=tenant_id, model_type=ModelType.TEXT_EMBEDDING, ) else: if indexing_technique == IndexTechniqueType.HIGH_QUALITY: - embedding_model_instance = self.model_manager.get_default_model_instance( + embedding_model_instance = self._get_model_manager(tenant_id).get_default_model_instance( tenant_id=tenant_id, model_type=ModelType.TEXT_EMBEDDING, ) @@ -574,7 +577,7 @@ class IndexingRunner: embedding_model_instance = None if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: - embedding_model_instance = self.model_manager.get_model_instance( + embedding_model_instance = self._get_model_manager(dataset.tenant_id).get_model_instance( tenant_id=dataset.tenant_id, provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, @@ -766,14 +769,14 @@ class IndexingRunner: embedding_model_instance = None if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: if dataset.embedding_model_provider: - embedding_model_instance = self.model_manager.get_model_instance( + embedding_model_instance = self._get_model_manager(dataset.tenant_id).get_model_instance( tenant_id=dataset.tenant_id, 
provider=dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, model=dataset.embedding_model, ) else: - embedding_model_instance = self.model_manager.get_default_model_instance( + embedding_model_instance = self._get_model_manager(dataset.tenant_id).get_default_model_instance( tenant_id=dataset.tenant_id, model_type=ModelType.TEXT_EMBEDDING, ) diff --git a/api/core/llm_generator/llm_generator.py b/api/core/llm_generator/llm_generator.py index c8848336d9..d39630ad95 100644 --- a/api/core/llm_generator/llm_generator.py +++ b/api/core/llm_generator/llm_generator.py @@ -5,6 +5,12 @@ from collections.abc import Sequence from typing import Protocol, cast import json_repair +from graphon.enums import WorkflowNodeExecutionMetadataKey +from graphon.model_runtime.entities.llm_entities import LLMResult +from graphon.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError +from sqlalchemy import select from core.app.app_config.entities import ModelConfig from core.llm_generator.entities import RuleCodeGeneratePayload, RuleGeneratePayload, RuleStructuredOutputPayload @@ -27,11 +33,6 @@ from core.ops.entities.trace_entity import TraceTaskName from core.ops.ops_trace_manager import TraceQueueManager, TraceTask from core.ops.utils import measure_time from core.prompt.utils.prompt_template_parser import PromptTemplateParser -from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey -from dify_graph.model_runtime.entities.llm_entities import LLMResult -from dify_graph.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage -from dify_graph.model_runtime.entities.model_entities import ModelType -from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from 
extensions.ext_database import db from extensions.ext_storage import storage from models import App, Message, WorkflowNodeExecutionModel @@ -62,7 +63,7 @@ class LLMGenerator: prompt += query + "\n" - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id) model_instance = model_manager.get_default_model_instance( tenant_id=tenant_id, model_type=ModelType.LLM, @@ -120,7 +121,7 @@ class LLMGenerator: prompt = prompt_template.format({"histories": histories, "format_instructions": format_instructions}) try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id) model_instance = model_manager.get_default_model_instance( tenant_id=tenant_id, model_type=ModelType.LLM, @@ -172,7 +173,7 @@ class LLMGenerator: prompt_messages = [UserPromptMessage(content=prompt_generate)] - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id) model_instance = model_manager.get_model_instance( tenant_id=tenant_id, @@ -219,7 +220,7 @@ class LLMGenerator: prompt_messages = [UserPromptMessage(content=prompt_generate_prompt)] # get model instance - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id) model_instance = model_manager.get_model_instance( tenant_id=tenant_id, model_type=ModelType.LLM, @@ -306,7 +307,7 @@ class LLMGenerator: remove_template_variables=False, ) - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id) model_instance = model_manager.get_model_instance( tenant_id=tenant_id, model_type=ModelType.LLM, @@ -337,7 +338,7 @@ class LLMGenerator: def generate_qa_document(cls, tenant_id: str, query, document_language: str): prompt = GENERATOR_QA_PROMPT.format(language=document_language) - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id) model_instance = model_manager.get_default_model_instance( tenant_id=tenant_id, model_type=ModelType.LLM, 
@@ -362,7 +363,7 @@ class LLMGenerator: @classmethod def generate_structured_output(cls, tenant_id: str, args: RuleStructuredOutputPayload): - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id) model_instance = model_manager.get_model_instance( tenant_id=tenant_id, model_type=ModelType.LLM, @@ -410,8 +411,8 @@ class LLMGenerator: model_config: ModelConfig, ideal_output: str | None, ): - last_run: Message | None = ( - db.session.query(Message).where(Message.app_id == flow_id).order_by(Message.created_at.desc()).first() + last_run: Message | None = db.session.scalar( + select(Message).where(Message.app_id == flow_id).order_by(Message.created_at.desc()).limit(1) ) if not last_run: return LLMGenerator.__instruction_modify_common( @@ -536,7 +537,7 @@ class LLMGenerator: injected_instruction = injected_instruction.replace(CURRENT, current or "null") if ERROR_MESSAGE in injected_instruction: injected_instruction = injected_instruction.replace(ERROR_MESSAGE, error_message or "null") - model_instance = ModelManager().get_model_instance( + model_instance = ModelManager.for_tenant(tenant_id=tenant_id).get_model_instance( tenant_id=tenant_id, model_type=ModelType.LLM, provider=model_config.provider, diff --git a/api/core/llm_generator/output_parser/structured_output.py b/api/core/llm_generator/output_parser/structured_output.py index 77ea1713ea..a1710f11ac 100644 --- a/api/core/llm_generator/output_parser/structured_output.py +++ b/api/core/llm_generator/output_parser/structured_output.py @@ -5,27 +5,27 @@ from enum import StrEnum from typing import Any, Literal, cast, overload import json_repair -from pydantic import TypeAdapter, ValidationError - -from core.llm_generator.output_parser.errors import OutputParserError -from core.llm_generator.prompts import STRUCTURED_OUTPUT_PROMPT -from core.model_manager import ModelInstance -from dify_graph.model_runtime.callbacks.base_callback import Callback -from 
dify_graph.model_runtime.entities.llm_entities import ( +from graphon.model_runtime.callbacks.base_callback import Callback +from graphon.model_runtime.entities.llm_entities import ( LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMResultChunkWithStructuredOutput, LLMResultWithStructuredOutput, ) -from dify_graph.model_runtime.entities.message_entities import ( +from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, PromptMessage, PromptMessageTool, SystemPromptMessage, TextPromptMessageContent, ) -from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ParameterRule +from graphon.model_runtime.entities.model_entities import AIModelEntity, ParameterRule +from pydantic import TypeAdapter, ValidationError + +from core.llm_generator.output_parser.errors import OutputParserError +from core.llm_generator.prompts import STRUCTURED_OUTPUT_PROMPT +from core.model_manager import ModelInstance class ResponseFormat(StrEnum): @@ -55,7 +55,6 @@ def invoke_llm_with_structured_output( tools: Sequence[PromptMessageTool] | None = None, stop: list[str] | None = None, stream: Literal[True], - user: str | None = None, callbacks: list[Callback] | None = None, ) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: ... @overload @@ -70,7 +69,6 @@ def invoke_llm_with_structured_output( tools: Sequence[PromptMessageTool] | None = None, stop: list[str] | None = None, stream: Literal[False], - user: str | None = None, callbacks: list[Callback] | None = None, ) -> LLMResultWithStructuredOutput: ... @overload @@ -85,7 +83,6 @@ def invoke_llm_with_structured_output( tools: Sequence[PromptMessageTool] | None = None, stop: list[str] | None = None, stream: bool = True, - user: str | None = None, callbacks: list[Callback] | None = None, ) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]: ... 
def invoke_llm_with_structured_output( @@ -99,7 +96,6 @@ def invoke_llm_with_structured_output( tools: Sequence[PromptMessageTool] | None = None, stop: list[str] | None = None, stream: bool = True, - user: str | None = None, callbacks: list[Callback] | None = None, ) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]: """ @@ -113,7 +109,6 @@ def invoke_llm_with_structured_output( :param tools: tools for tool calling :param stop: stop words :param stream: is stream response - :param user: unique user id :param callbacks: callbacks :return: full response or stream response chunk generator result """ @@ -143,7 +138,6 @@ def invoke_llm_with_structured_output( tools=tools, stop=stop, stream=stream, - user=user, callbacks=callbacks, ) diff --git a/api/core/mcp/server/streamable_http.py b/api/core/mcp/server/streamable_http.py index de68eb268b..27000c947c 100644 --- a/api/core/mcp/server/streamable_http.py +++ b/api/core/mcp/server/streamable_http.py @@ -3,11 +3,12 @@ import logging from collections.abc import Mapping from typing import Any, cast +from graphon.variables.input_entities import VariableEntity, VariableEntityType + from configs import dify_config from core.app.entities.app_invoke_entities import InvokeFrom from core.app.features.rate_limiting.rate_limit import RateLimitGenerator from core.mcp import types as mcp_types -from dify_graph.variables.input_entities import VariableEntity, VariableEntityType from models.model import App, AppMCPServer, AppMode, EndUser from services.app_generate_service import AppGenerateService diff --git a/api/core/mcp/utils.py b/api/core/mcp/utils.py index db9cb726d7..7e35044176 100644 --- a/api/core/mcp/utils.py +++ b/api/core/mcp/utils.py @@ -4,11 +4,11 @@ from contextlib import AbstractContextManager import httpx import httpx_sse +from graphon.model_runtime.utils.encoders import jsonable_encoder from httpx_sse import connect_sse from configs import dify_config from core.mcp.types import 
ErrorData, JSONRPCError -from dify_graph.model_runtime.utils.encoders import jsonable_encoder HTTP_REQUEST_NODE_SSL_VERIFY = dify_config.HTTP_REQUEST_NODE_SSL_VERIFY diff --git a/api/core/memory/token_buffer_memory.py b/api/core/memory/token_buffer_memory.py index 1156a98af1..09c84538a9 100644 --- a/api/core/memory/token_buffer_memory.py +++ b/api/core/memory/token_buffer_memory.py @@ -1,13 +1,7 @@ from collections.abc import Sequence -from sqlalchemy import select -from sqlalchemy.orm import sessionmaker - -from core.app.app_config.features.file_upload.manager import FileUploadConfigManager -from core.model_manager import ModelInstance -from core.prompt.utils.extract_thread_messages import extract_thread_messages -from dify_graph.file import file_manager -from dify_graph.model_runtime.entities import ( +from graphon.file import file_manager +from graphon.model_runtime.entities import ( AssistantPromptMessage, ImagePromptMessageContent, PromptMessage, @@ -15,7 +9,14 @@ from dify_graph.model_runtime.entities import ( TextPromptMessageContent, UserPromptMessage, ) -from dify_graph.model_runtime.entities.message_entities import PromptMessageContentUnionTypes +from graphon.model_runtime.entities.message_entities import PromptMessageContentUnionTypes +from sqlalchemy import select +from sqlalchemy.orm import sessionmaker + +from core.app.app_config.features.file_upload.manager import FileUploadConfigManager +from core.app.file_access import DatabaseFileAccessController +from core.model_manager import ModelInstance +from core.prompt.utils.extract_thread_messages import extract_thread_messages from extensions.ext_database import db from factories import file_factory from models.model import AppMode, Conversation, Message, MessageFile @@ -23,6 +24,8 @@ from models.workflow import Workflow from repositories.api_workflow_run_repository import APIWorkflowRunRepository from repositories.factory import DifyAPIRepositoryFactory +_file_access_controller = 
DatabaseFileAccessController() + class TokenBufferMemory: def __init__( @@ -85,7 +88,10 @@ class TokenBufferMemory: # Build files directly without filtering by belongs_to file_objs = [ file_factory.build_from_message_file( - message_file=message_file, tenant_id=app_record.tenant_id, config=file_extra_config + message_file=message_file, + tenant_id=app_record.tenant_id, + config=file_extra_config, + access_controller=_file_access_controller, ) for message_file in message_files ] diff --git a/api/core/model_manager.py b/api/core/model_manager.py index 0f710a8fcf..87d1d7fba6 100644 --- a/api/core/model_manager.py +++ b/api/core/model_manager.py @@ -2,25 +2,27 @@ import logging from collections.abc import Callable, Generator, Iterable, Mapping, Sequence from typing import IO, Any, Literal, Optional, Union, cast, overload +from graphon.model_runtime.callbacks.base_callback import Callback +from graphon.model_runtime.entities.llm_entities import LLMResult +from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool +from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelFeature, ModelType +from graphon.model_runtime.entities.rerank_entities import RerankResult +from graphon.model_runtime.entities.text_embedding_entities import EmbeddingResult +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeConnectionError, InvokeRateLimitError +from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from graphon.model_runtime.model_providers.__base.moderation_model import ModerationModel +from graphon.model_runtime.model_providers.__base.rerank_model import RerankModel +from graphon.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel +from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel +from graphon.model_runtime.model_providers.__base.tts_model import TTSModel + from configs import 
dify_config from core.entities.embedding_type import EmbeddingInputType from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle from core.entities.provider_entities import ModelLoadBalancingConfiguration from core.errors.error import ProviderTokenNotInitError +from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager from core.provider_manager import ProviderManager -from dify_graph.model_runtime.callbacks.base_callback import Callback -from dify_graph.model_runtime.entities.llm_entities import LLMResult -from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool -from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelType -from dify_graph.model_runtime.entities.rerank_entities import RerankResult -from dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingResult -from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeConnectionError, InvokeRateLimitError -from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from dify_graph.model_runtime.model_providers.__base.moderation_model import ModerationModel -from dify_graph.model_runtime.model_providers.__base.rerank_model import RerankModel -from dify_graph.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel -from dify_graph.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel -from dify_graph.model_runtime.model_providers.__base.tts_model import TTSModel from extensions.ext_redis import redis_client from models.provider import ProviderType from services.enterprise.plugin_manager_service import PluginCredentialType @@ -30,7 +32,7 @@ logger = logging.getLogger(__name__) class ModelInstance: """ - Model instance class + Model instance class. 
""" def __init__(self, provider_model_bundle: ProviderModelBundle, model: str): @@ -49,6 +51,13 @@ class ModelInstance: credentials=self.credentials, ) + def get_model_schema(self) -> AIModelEntity: + """Return the resolved schema for the current model instance.""" + model_schema = self.model_type_instance.get_model_schema(self.model_name, self.credentials) + if model_schema is None: + raise ValueError(f"model schema not found for {self.model_name}") + return model_schema + @staticmethod def _fetch_credentials_from_bundle(provider_model_bundle: ProviderModelBundle, model: str): """ @@ -110,7 +119,6 @@ class ModelInstance: tools: Sequence[PromptMessageTool] | None = None, stop: list[str] | None = None, stream: Literal[True] = True, - user: str | None = None, callbacks: list[Callback] | None = None, ) -> Generator: ... @@ -122,7 +130,6 @@ class ModelInstance: tools: Sequence[PromptMessageTool] | None = None, stop: list[str] | None = None, stream: Literal[False] = False, - user: str | None = None, callbacks: list[Callback] | None = None, ) -> LLMResult: ... @@ -134,7 +141,6 @@ class ModelInstance: tools: Sequence[PromptMessageTool] | None = None, stop: list[str] | None = None, stream: bool = True, - user: str | None = None, callbacks: list[Callback] | None = None, ) -> Union[LLMResult, Generator]: ... 
@@ -145,7 +151,6 @@ class ModelInstance: tools: Sequence[PromptMessageTool] | None = None, stop: Sequence[str] | None = None, stream: bool = True, - user: str | None = None, callbacks: list[Callback] | None = None, ) -> Union[LLMResult, Generator]: """ @@ -156,7 +161,6 @@ class ModelInstance: :param tools: tools for tool calling :param stop: stop words :param stream: is stream response - :param user: unique user id :param callbacks: callbacks :return: full response or stream response chunk generator result """ @@ -173,7 +177,6 @@ class ModelInstance: tools=tools, stop=stop, stream=stream, - user=user, callbacks=callbacks, ), ) @@ -202,13 +205,12 @@ class ModelInstance: ) def invoke_text_embedding( - self, texts: list[str], user: str | None = None, input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT + self, texts: list[str], input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT ) -> EmbeddingResult: """ Invoke large language model :param texts: texts to embed - :param user: unique user id :param input_type: input type :return: embeddings result """ @@ -221,7 +223,6 @@ class ModelInstance: model=self.model_name, credentials=self.credentials, texts=texts, - user=user, input_type=input_type, ), ) @@ -229,14 +230,12 @@ class ModelInstance: def invoke_multimodal_embedding( self, multimodel_documents: list[dict], - user: str | None = None, input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT, ) -> EmbeddingResult: """ Invoke large language model :param multimodel_documents: multimodel documents to embed - :param user: unique user id :param input_type: input type :return: embeddings result """ @@ -249,7 +248,6 @@ class ModelInstance: model=self.model_name, credentials=self.credentials, multimodel_documents=multimodel_documents, - user=user, input_type=input_type, ), ) @@ -279,7 +277,6 @@ class ModelInstance: docs: list[str], score_threshold: float | None = None, top_n: int | None = None, - user: str | None = None, ) -> RerankResult: """ Invoke 
rerank model @@ -288,7 +285,6 @@ class ModelInstance: :param docs: docs for reranking :param score_threshold: score threshold :param top_n: top n - :param user: unique user id :return: rerank result """ if not isinstance(self.model_type_instance, RerankModel): @@ -303,7 +299,6 @@ class ModelInstance: docs=docs, score_threshold=score_threshold, top_n=top_n, - user=user, ), ) @@ -313,7 +308,6 @@ class ModelInstance: docs: list[dict], score_threshold: float | None = None, top_n: int | None = None, - user: str | None = None, ) -> RerankResult: """ Invoke rerank model @@ -322,7 +316,6 @@ class ModelInstance: :param docs: docs for reranking :param score_threshold: score threshold :param top_n: top n - :param user: unique user id :return: rerank result """ if not isinstance(self.model_type_instance, RerankModel): @@ -337,16 +330,14 @@ class ModelInstance: docs=docs, score_threshold=score_threshold, top_n=top_n, - user=user, ), ) - def invoke_moderation(self, text: str, user: str | None = None) -> bool: + def invoke_moderation(self, text: str) -> bool: """ Invoke moderation model :param text: text to moderate - :param user: unique user id :return: false if text is safe, true otherwise """ if not isinstance(self.model_type_instance, ModerationModel): @@ -358,16 +349,14 @@ class ModelInstance: model=self.model_name, credentials=self.credentials, text=text, - user=user, ), ) - def invoke_speech2text(self, file: IO[bytes], user: str | None = None) -> str: + def invoke_speech2text(self, file: IO[bytes]) -> str: """ Invoke large language model :param file: audio file - :param user: unique user id :return: text for given audio file """ if not isinstance(self.model_type_instance, Speech2TextModel): @@ -379,18 +368,15 @@ class ModelInstance: model=self.model_name, credentials=self.credentials, file=file, - user=user, ), ) - def invoke_tts(self, content_text: str, tenant_id: str, voice: str, user: str | None = None) -> Iterable[bytes]: + def invoke_tts(self, content_text: str, 
voice: str = "") -> Iterable[bytes]: """ Invoke large language tts model :param content_text: text content to be translated - :param tenant_id: user tenant id :param voice: model timbre - :param user: unique user id :return: text for given audio file """ if not isinstance(self.model_type_instance, TTSModel): @@ -402,8 +388,6 @@ class ModelInstance: model=self.model_name, credentials=self.credentials, content_text=content_text, - user=user, - tenant_id=tenant_id, voice=voice, ), ) @@ -477,10 +461,20 @@ class ModelInstance: class ModelManager: - def __init__(self): - self._provider_manager = ProviderManager() + def __init__(self, provider_manager: ProviderManager): + self._provider_manager = provider_manager - def get_model_instance(self, tenant_id: str, provider: str, model_type: ModelType, model: str) -> ModelInstance: + @classmethod + def for_tenant(cls, tenant_id: str, user_id: str | None = None) -> "ModelManager": + return cls(provider_manager=create_plugin_provider_manager(tenant_id=tenant_id, user_id=user_id)) + + def get_model_instance( + self, + tenant_id: str, + provider: str, + model_type: ModelType, + model: str, + ) -> ModelInstance: """ Get model instance :param tenant_id: tenant id @@ -496,7 +490,8 @@ class ModelManager: tenant_id=tenant_id, provider=provider, model_type=model_type ) - return ModelInstance(provider_model_bundle, model) + model_instance = ModelInstance(provider_model_bundle, model) + return model_instance def get_default_provider_model_name(self, tenant_id: str, model_type: ModelType) -> tuple[str | None, str | None]: """ diff --git a/api/core/moderation/openai_moderation/openai_moderation.py b/api/core/moderation/openai_moderation/openai_moderation.py index 06676f5cf4..dd038c77f1 100644 --- a/api/core/moderation/openai_moderation/openai_moderation.py +++ b/api/core/moderation/openai_moderation/openai_moderation.py @@ -1,6 +1,7 @@ +from graphon.model_runtime.entities.model_entities import ModelType + from core.model_manager import 
ModelManager from core.moderation.base import Moderation, ModerationAction, ModerationInputsResult, ModerationOutputsResult -from dify_graph.model_runtime.entities.model_entities import ModelType class OpenAIModeration(Moderation): @@ -50,7 +51,7 @@ class OpenAIModeration(Moderation): def _is_violated(self, inputs: dict): text = "\n".join(str(inputs.values())) - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=self.tenant_id) model_instance = model_manager.get_model_instance( tenant_id=self.tenant_id, provider="openai", model_type=ModelType.MODERATION, model="omni-moderation-latest" ) diff --git a/api/core/ops/aliyun_trace/aliyun_trace.py b/api/core/ops/aliyun_trace/aliyun_trace.py index 18f35b5b9c..70aaf2a07b 100644 --- a/api/core/ops/aliyun_trace/aliyun_trace.py +++ b/api/core/ops/aliyun_trace/aliyun_trace.py @@ -1,6 +1,8 @@ import logging from collections.abc import Sequence +from graphon.entities import WorkflowNodeExecution +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from opentelemetry.trace import SpanKind from sqlalchemy.orm import sessionmaker @@ -57,8 +59,6 @@ from core.ops.entities.trace_entity import ( WorkflowTraceInfo, ) from core.repositories import DifyCoreRepositoryFactory -from dify_graph.entities import WorkflowNodeExecution -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from extensions.ext_database import db from models import WorkflowNodeExecutionTriggeredFrom @@ -296,7 +296,9 @@ class AliyunDataTrace(BaseTraceInstance): triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, ) - return workflow_node_execution_repository.get_by_workflow_run(workflow_run_id=trace_info.workflow_run_id) + return workflow_node_execution_repository.get_by_workflow_execution( + workflow_execution_id=trace_info.workflow_run_id + ) def build_workflow_node_span( self, node_execution: WorkflowNodeExecution, trace_info: WorkflowTraceInfo, trace_metadata: 
TraceMetadata diff --git a/api/core/ops/aliyun_trace/utils.py b/api/core/ops/aliyun_trace/utils.py index 45319f24c1..d8e105d6a3 100644 --- a/api/core/ops/aliyun_trace/utils.py +++ b/api/core/ops/aliyun_trace/utils.py @@ -2,6 +2,8 @@ import json from collections.abc import Mapping from typing import Any +from graphon.entities import WorkflowNodeExecution +from graphon.enums import WorkflowNodeExecutionStatus from opentelemetry.trace import Link, Status, StatusCode from core.ops.aliyun_trace.entities.semconv import ( @@ -14,8 +16,6 @@ from core.ops.aliyun_trace.entities.semconv import ( GenAISpanKind, ) from core.rag.models.document import Document -from dify_graph.entities import WorkflowNodeExecution -from dify_graph.enums import WorkflowNodeExecutionStatus from extensions.ext_database import db from models import EndUser @@ -27,9 +27,7 @@ DEFAULT_FRAMEWORK_NAME = "dify" def get_user_id_from_message_data(message_data) -> str: user_id = message_data.from_account_id if message_data.from_end_user_id: - end_user_data: EndUser | None = ( - db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first() - ) + end_user_data: EndUser | None = db.session.get(EndUser, message_data.from_end_user_id) if end_user_data is not None: user_id = end_user_data.session_id return user_id diff --git a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py b/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py index f54461e99a..39d97e2882 100644 --- a/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py +++ b/api/core/ops/arize_phoenix_trace/arize_phoenix_trace.py @@ -6,6 +6,7 @@ from datetime import datetime, timedelta from typing import Any, Union, cast from urllib.parse import urlparse +from graphon.enums import WorkflowNodeExecutionStatus from openinference.semconv.trace import ( MessageAttributes, OpenInferenceMimeTypeValues, @@ -271,8 +272,8 @@ class ArizePhoenixDataTrace(BaseTraceInstance): ) # Get all executions for this workflow run - 
workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run( - workflow_run_id=trace_info.workflow_run_id + workflow_node_executions = workflow_node_execution_repository.get_by_workflow_execution( + workflow_execution_id=trace_info.workflow_run_id ) try: @@ -300,7 +301,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance): "app_name": node_execution.title, "status": node_execution.status, "status_message": node_execution.error or "", - "level": "ERROR" if node_execution.status == "failed" else "DEFAULT", + "level": "ERROR" if node_execution.status == WorkflowNodeExecutionStatus.FAILED else "DEFAULT", } ) @@ -361,7 +362,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance): llm_attributes.update(self._construct_llm_attributes(process_data.get("prompts", []))) node_span.set_attributes(llm_attributes) finally: - if node_execution.status == "failed": + if node_execution.status == WorkflowNodeExecutionStatus.FAILED: set_span_status(node_span, node_execution.error) else: set_span_status(node_span) @@ -409,9 +410,7 @@ class ArizePhoenixDataTrace(BaseTraceInstance): # Add end user data if available if trace_info.message_data.from_end_user_id: - end_user_data: EndUser | None = ( - db.session.query(EndUser).where(EndUser.id == trace_info.message_data.from_end_user_id).first() - ) + end_user_data: EndUser | None = db.session.get(EndUser, trace_info.message_data.from_end_user_id) if end_user_data is not None: metadata["end_user_id"] = end_user_data.session_id diff --git a/api/core/ops/entities/trace_entity.py b/api/core/ops/entities/trace_entity.py index 50a2cdea63..45b2f635ba 100644 --- a/api/core/ops/entities/trace_entity.py +++ b/api/core/ops/entities/trace_entity.py @@ -9,8 +9,8 @@ from pydantic import BaseModel, ConfigDict, field_serializer, field_validator class BaseTraceInfo(BaseModel): message_id: str | None = None message_data: Any | None = None - inputs: Union[str, dict[str, Any], list] | None = None - outputs: Union[str, dict[str, Any], list] | 
None = None + inputs: Union[str, dict[str, Any], list[Any]] | None = None + outputs: Union[str, dict[str, Any], list[Any]] | None = None start_time: datetime | None = None end_time: datetime | None = None metadata: dict[str, Any] @@ -18,7 +18,7 @@ class BaseTraceInfo(BaseModel): @field_validator("inputs", "outputs") @classmethod - def ensure_type(cls, v): + def ensure_type(cls, v: str | dict[str, Any] | list[Any] | None) -> str | dict[str, Any] | list[Any] | None: if v is None: return None if isinstance(v, str | dict | list): @@ -27,6 +27,48 @@ class BaseTraceInfo(BaseModel): model_config = ConfigDict(protected_namespaces=()) + @property + def resolved_trace_id(self) -> str | None: + """Get trace_id with intelligent fallback. + + Priority: + 1. External trace_id (from X-Trace-Id header) + 2. workflow_run_id (if this trace type has it) + 3. message_id (as final fallback) + """ + if self.trace_id: + return self.trace_id + + # Try workflow_run_id (only exists on workflow-related traces) + workflow_run_id = getattr(self, "workflow_run_id", None) + if workflow_run_id: + return workflow_run_id + + # Final fallback to message_id + return str(self.message_id) if self.message_id else None + + @property + def resolved_parent_context(self) -> tuple[str | None, str | None]: + """Resolve cross-workflow parent linking from metadata. + + Extracts typed parent IDs from the untyped ``parent_trace_context`` + metadata dict (set by tool_node when invoking nested workflows). + + Returns: + (trace_correlation_override, parent_span_id_source) where + trace_correlation_override is the outer workflow_run_id and + parent_span_id_source is the outer node_execution_id. 
+ """ + parent_ctx = self.metadata.get("parent_trace_context") + if not isinstance(parent_ctx, dict): + return None, None + trace_override = parent_ctx.get("parent_workflow_run_id") + parent_span = parent_ctx.get("parent_node_execution_id") + return ( + trace_override if isinstance(trace_override, str) else None, + parent_span if isinstance(parent_span, str) else None, + ) + @field_serializer("start_time", "end_time") def serialize_datetime(self, dt: datetime | None) -> str | None: if dt is None: @@ -48,7 +90,10 @@ class WorkflowTraceInfo(BaseTraceInfo): workflow_run_version: str error: str | None = None total_tokens: int + prompt_tokens: int | None = None + completion_tokens: int | None = None file_list: list[str] + invoked_by: str | None = None query: str metadata: dict[str, Any] @@ -59,7 +104,7 @@ class MessageTraceInfo(BaseTraceInfo): answer_tokens: int total_tokens: int error: str | None = None - file_list: Union[str, dict[str, Any], list] | None = None + file_list: Union[str, dict[str, Any], list[Any]] | None = None message_file_data: Any | None = None conversation_mode: str gen_ai_server_time_to_first_token: float | None = None @@ -106,7 +151,7 @@ class ToolTraceInfo(BaseTraceInfo): tool_config: dict[str, Any] time_cost: Union[int, float] tool_parameters: dict[str, Any] - file_url: Union[str, None, list] = None + file_url: Union[str, None, list[str]] = None class GenerateNameTraceInfo(BaseTraceInfo): @@ -114,6 +159,79 @@ class GenerateNameTraceInfo(BaseTraceInfo): tenant_id: str +class PromptGenerationTraceInfo(BaseTraceInfo): + """Trace information for prompt generation operations (rule-generate, code-generate, etc.).""" + + tenant_id: str + user_id: str + app_id: str | None = None + + operation_type: str + instruction: str + + prompt_tokens: int + completion_tokens: int + total_tokens: int + + model_provider: str + model_name: str + + latency: float + + total_price: float | None = None + currency: str | None = None + + error: str | None = None + + 
model_config = ConfigDict(protected_namespaces=()) + + +class WorkflowNodeTraceInfo(BaseTraceInfo): + workflow_id: str + workflow_run_id: str + tenant_id: str + node_execution_id: str + node_id: str + node_type: str + title: str + + status: str + error: str | None = None + elapsed_time: float + + index: int + predecessor_node_id: str | None = None + + total_tokens: int = 0 + total_price: float = 0.0 + currency: str | None = None + + model_provider: str | None = None + model_name: str | None = None + prompt_tokens: int | None = None + completion_tokens: int | None = None + + tool_name: str | None = None + + iteration_id: str | None = None + iteration_index: int | None = None + loop_id: str | None = None + loop_index: int | None = None + parallel_id: str | None = None + + node_inputs: Mapping[str, Any] | None = None + node_outputs: Mapping[str, Any] | None = None + process_data: Mapping[str, Any] | None = None + + invoked_by: str | None = None + + model_config = ConfigDict(protected_namespaces=()) + + +class DraftNodeExecutionTrace(WorkflowNodeTraceInfo): + pass + + class TaskData(BaseModel): app_id: str trace_info_type: str @@ -128,11 +246,31 @@ trace_info_info_map = { "DatasetRetrievalTraceInfo": DatasetRetrievalTraceInfo, "ToolTraceInfo": ToolTraceInfo, "GenerateNameTraceInfo": GenerateNameTraceInfo, + "PromptGenerationTraceInfo": PromptGenerationTraceInfo, + "WorkflowNodeTraceInfo": WorkflowNodeTraceInfo, + "DraftNodeExecutionTrace": DraftNodeExecutionTrace, } +class OperationType(StrEnum): + """Operation type for token metric labels. + + Used as a metric attribute on ``dify.tokens.input`` / ``dify.tokens.output`` + counters so consumers can break down token usage by operation. 
+ """ + + WORKFLOW = "workflow" + NODE_EXECUTION = "node_execution" + MESSAGE = "message" + RULE_GENERATE = "rule_generate" + CODE_GENERATE = "code_generate" + STRUCTURED_OUTPUT = "structured_output" + INSTRUCTION_MODIFY = "instruction_modify" + + class TraceTaskName(StrEnum): CONVERSATION_TRACE = "conversation" + DRAFT_NODE_EXECUTION_TRACE = "draft_node_execution" WORKFLOW_TRACE = "workflow" MESSAGE_TRACE = "message" MODERATION_TRACE = "moderation" @@ -140,4 +278,6 @@ class TraceTaskName(StrEnum): DATASET_RETRIEVAL_TRACE = "dataset_retrieval" TOOL_TRACE = "tool" GENERATE_NAME_TRACE = "generate_conversation_name" + PROMPT_GENERATION_TRACE = "prompt_generation" + NODE_EXECUTION_TRACE = "node_execution" DATASOURCE_TRACE = "datasource" diff --git a/api/core/ops/langfuse_trace/langfuse_trace.py b/api/core/ops/langfuse_trace/langfuse_trace.py index 6e62387a1f..3644b6b4c2 100644 --- a/api/core/ops/langfuse_trace/langfuse_trace.py +++ b/api/core/ops/langfuse_trace/langfuse_trace.py @@ -2,6 +2,7 @@ import logging import os from datetime import datetime, timedelta +from graphon.enums import BuiltinNodeTypes from langfuse import Langfuse from sqlalchemy.orm import sessionmaker @@ -28,7 +29,6 @@ from core.ops.langfuse_trace.entities.langfuse_trace_entity import ( ) from core.ops.utils import filter_none_values from core.repositories import DifyCoreRepositoryFactory -from dify_graph.enums import BuiltinNodeTypes from extensions.ext_database import db from models import EndUser, WorkflowNodeExecutionTriggeredFrom from models.enums import MessageStatus @@ -130,8 +130,8 @@ class LangFuseDataTrace(BaseTraceInstance): ) # Get all executions for this workflow run - workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run( - workflow_run_id=trace_info.workflow_run_id + workflow_node_executions = workflow_node_execution_repository.get_by_workflow_execution( + workflow_execution_id=trace_info.workflow_run_id ) for node_execution in workflow_node_executions: @@ 
-241,9 +241,7 @@ class LangFuseDataTrace(BaseTraceInstance): user_id = message_data.from_account_id if message_data.from_end_user_id: - end_user_data: EndUser | None = ( - db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first() - ) + end_user_data: EndUser | None = db.session.get(EndUser, message_data.from_end_user_id) if end_user_data is not None: user_id = end_user_data.session_id metadata["user_id"] = user_id diff --git a/api/core/ops/langsmith_trace/langsmith_trace.py b/api/core/ops/langsmith_trace/langsmith_trace.py index 32a0c77fe2..490c64af84 100644 --- a/api/core/ops/langsmith_trace/langsmith_trace.py +++ b/api/core/ops/langsmith_trace/langsmith_trace.py @@ -4,6 +4,7 @@ import uuid from datetime import datetime, timedelta from typing import cast +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from langsmith import Client from langsmith.schemas import RunBase from sqlalchemy.orm import sessionmaker @@ -28,7 +29,6 @@ from core.ops.langsmith_trace.entities.langsmith_trace_entity import ( ) from core.ops.utils import filter_none_values, generate_dotted_order from core.repositories import DifyCoreRepositoryFactory -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from extensions.ext_database import db from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom @@ -152,8 +152,8 @@ class LangSmithDataTrace(BaseTraceInstance): ) # Get all executions for this workflow run - workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run( - workflow_run_id=trace_info.workflow_run_id + workflow_node_executions = workflow_node_execution_repository.get_by_workflow_execution( + workflow_execution_id=trace_info.workflow_run_id ) for node_execution in workflow_node_executions: @@ -259,9 +259,7 @@ class LangSmithDataTrace(BaseTraceInstance): metadata["user_id"] = user_id if message_data.from_end_user_id: - end_user_data: EndUser | None = ( - 
db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first() - ) + end_user_data: EndUser | None = db.session.get(EndUser, message_data.from_end_user_id) if end_user_data is not None: end_user_id = end_user_data.session_id metadata["end_user_id"] = end_user_id diff --git a/api/core/ops/mlflow_trace/mlflow_trace.py b/api/core/ops/mlflow_trace/mlflow_trace.py index ab4a7650ec..946d3cdd47 100644 --- a/api/core/ops/mlflow_trace/mlflow_trace.py +++ b/api/core/ops/mlflow_trace/mlflow_trace.py @@ -5,10 +5,12 @@ from datetime import datetime, timedelta from typing import Any, cast import mlflow +from graphon.enums import BuiltinNodeTypes from mlflow.entities import Document, Span, SpanEvent, SpanStatusCode, SpanType from mlflow.tracing.constant import SpanAttributeKey, TokenUsageKey, TraceMetadataKey from mlflow.tracing.fluent import start_span_no_context, update_current_trace from mlflow.tracing.provider import detach_span_from_context, set_span_in_context +from sqlalchemy import select from core.ops.base_trace_instance import BaseTraceInstance from core.ops.entities.config_entity import DatabricksConfig, MLflowConfig @@ -23,7 +25,6 @@ from core.ops.entities.trace_entity import ( TraceTaskName, WorkflowTraceInfo, ) -from dify_graph.enums import BuiltinNodeTypes from extensions.ext_database import db from models import EndUser from models.workflow import WorkflowNodeExecutionModel @@ -320,7 +321,7 @@ class MLflowDataTrace(BaseTraceInstance): def _get_message_user_id(self, metadata: dict) -> str | None: if (end_user_id := metadata.get("from_end_user_id")) and ( - end_user_data := db.session.query(EndUser).where(EndUser.id == end_user_id).first() + end_user_data := db.session.get(EndUser, end_user_id) ): return end_user_data.session_id @@ -447,25 +448,11 @@ class MLflowDataTrace(BaseTraceInstance): def _get_workflow_nodes(self, workflow_run_id: str): """Helper method to get workflow nodes""" - workflow_nodes = ( - db.session.query( - 
WorkflowNodeExecutionModel.id, - WorkflowNodeExecutionModel.tenant_id, - WorkflowNodeExecutionModel.app_id, - WorkflowNodeExecutionModel.title, - WorkflowNodeExecutionModel.node_type, - WorkflowNodeExecutionModel.status, - WorkflowNodeExecutionModel.inputs, - WorkflowNodeExecutionModel.outputs, - WorkflowNodeExecutionModel.created_at, - WorkflowNodeExecutionModel.elapsed_time, - WorkflowNodeExecutionModel.process_data, - WorkflowNodeExecutionModel.execution_metadata, - ) - .filter(WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id) + workflow_nodes = db.session.scalars( + select(WorkflowNodeExecutionModel) + .where(WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id) .order_by(WorkflowNodeExecutionModel.created_at) - .all() - ) + ).all() return workflow_nodes def _get_node_span_type(self, node_type: str) -> str: diff --git a/api/core/ops/opik_trace/opik_trace.py b/api/core/ops/opik_trace/opik_trace.py index fb72bc2381..2215bdeb33 100644 --- a/api/core/ops/opik_trace/opik_trace.py +++ b/api/core/ops/opik_trace/opik_trace.py @@ -5,6 +5,7 @@ import uuid from datetime import datetime, timedelta from typing import cast +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from opik import Opik, Trace from opik.id_helpers import uuid4_to_uuid7 from sqlalchemy.orm import sessionmaker @@ -23,7 +24,6 @@ from core.ops.entities.trace_entity import ( WorkflowTraceInfo, ) from core.repositories import DifyCoreRepositoryFactory -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from extensions.ext_database import db from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom @@ -176,8 +176,8 @@ class OpikDataTrace(BaseTraceInstance): ) # Get all executions for this workflow run - workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run( - workflow_run_id=trace_info.workflow_run_id + workflow_node_executions = 
workflow_node_execution_repository.get_by_workflow_execution( + workflow_execution_id=trace_info.workflow_run_id ) for node_execution in workflow_node_executions: @@ -288,9 +288,7 @@ class OpikDataTrace(BaseTraceInstance): metadata["file_list"] = file_list if message_data.from_end_user_id: - end_user_data: EndUser | None = ( - db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first() - ) + end_user_data: EndUser | None = db.session.get(EndUser, message_data.from_end_user_id) if end_user_data is not None: end_user_id = end_user_data.session_id metadata["end_user_id"] = end_user_id diff --git a/api/core/ops/ops_trace_manager.py b/api/core/ops/ops_trace_manager.py index 9ac753240b..9c36d57c6f 100644 --- a/api/core/ops/ops_trace_manager.py +++ b/api/core/ops/ops_trace_manager.py @@ -15,34 +15,179 @@ from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker from core.helper.encrypter import batch_decrypt_token, encrypt_token, obfuscated_token -from core.ops.entities.config_entity import OPS_FILE_PATH, TracingProviderEnum +from core.ops.entities.config_entity import ( + OPS_FILE_PATH, + TracingProviderEnum, +) from core.ops.entities.trace_entity import ( DatasetRetrievalTraceInfo, + DraftNodeExecutionTrace, GenerateNameTraceInfo, MessageTraceInfo, ModerationTraceInfo, + PromptGenerationTraceInfo, SuggestedQuestionTraceInfo, TaskData, ToolTraceInfo, TraceTaskName, + WorkflowNodeTraceInfo, WorkflowTraceInfo, ) from core.ops.utils import get_message_data +from extensions.ext_database import db from extensions.ext_storage import storage -from models.engine import db +from models.account import Tenant +from models.dataset import Dataset from models.model import App, AppModelConfig, Conversation, Message, MessageFile, TraceAppConfig +from models.provider import Provider, ProviderCredential, ProviderModel, ProviderModelCredential, ProviderType +from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, 
WorkflowToolProvider from models.workflow import WorkflowAppLog from tasks.ops_trace_task import process_trace_tasks if TYPE_CHECKING: - from dify_graph.entities import WorkflowExecution + from graphon.entities import WorkflowExecution logger = logging.getLogger(__name__) +def _lookup_app_and_workspace_names(app_id: str | None, tenant_id: str | None) -> tuple[str, str]: + """Return (app_name, workspace_name) for the given IDs. Falls back to empty strings.""" + app_name = "" + workspace_name = "" + if not app_id and not tenant_id: + return app_name, workspace_name + with Session(db.engine) as session: + if app_id: + name = session.scalar(select(App.name).where(App.id == app_id)) + if name: + app_name = name + if tenant_id: + name = session.scalar(select(Tenant.name).where(Tenant.id == tenant_id)) + if name: + workspace_name = name + return app_name, workspace_name + + +_PROVIDER_TYPE_TO_MODEL: dict[str, type] = { + "builtin": BuiltinToolProvider, + "plugin": BuiltinToolProvider, + "api": ApiToolProvider, + "workflow": WorkflowToolProvider, + "mcp": MCPToolProvider, +} + + +def _lookup_credential_name(credential_id: str | None, provider_type: str | None) -> str: + if not credential_id: + return "" + model_cls = _PROVIDER_TYPE_TO_MODEL.get(provider_type or "") + if not model_cls: + return "" + with Session(db.engine) as session: + name = session.scalar(select(model_cls.name).where(model_cls.id == credential_id)) # type: ignore[attr-defined] + return str(name) if name else "" + + +def _lookup_llm_credential_info( + tenant_id: str | None, provider: str | None, model: str | None, model_type: str | None = "llm" +) -> tuple[str | None, str]: + """ + Lookup LLM credential ID and name for the given provider and model. + Returns (credential_id, credential_name). + + Handles async timing issues gracefully - if credential is deleted between lookups, + returns the ID but empty name rather than failing. 
+ """ + if not tenant_id or not provider: + return None, "" + + try: + with Session(db.engine) as session: + # Try to find provider-level or model-level configuration + provider_record = session.scalar( + select(Provider).where( + Provider.tenant_id == tenant_id, + Provider.provider_name == provider, + Provider.provider_type == ProviderType.CUSTOM, + ) + ) + + if not provider_record: + return None, "" + + # Check if there's a model-specific config + credential_id = None + credential_name = "" + is_model_level = False + + if model: + # Try model-level first + model_record = session.scalar( + select(ProviderModel).where( + ProviderModel.tenant_id == tenant_id, + ProviderModel.provider_name == provider, + ProviderModel.model_name == model, + ProviderModel.model_type == model_type, + ) + ) + + if model_record and model_record.credential_id: + credential_id = model_record.credential_id + is_model_level = True + + if not credential_id and provider_record.credential_id: + # Fall back to provider-level credential + credential_id = provider_record.credential_id + is_model_level = False + + # Lookup credential_name if we have credential_id + if credential_id: + try: + if is_model_level: + # Query ProviderModelCredential + cred_name = session.scalar( + select(ProviderModelCredential.credential_name).where( + ProviderModelCredential.id == credential_id + ) + ) + else: + # Query ProviderCredential + cred_name = session.scalar( + select(ProviderCredential.credential_name).where(ProviderCredential.id == credential_id) + ) + + if cred_name: + credential_name = str(cred_name) + except Exception as e: + # Credential might have been deleted between lookups (async timing) + # Return ID but empty name rather than failing + logger.warning( + "Failed to lookup credential name for credential_id=%s (provider=%s, model=%s): %s", + credential_id, + provider, + model, + str(e), + exc_info=True, + ) + + return credential_id, credential_name + except Exception as e: + # Database query failed or 
other unexpected error + # Return empty rather than propagating error to telemetry emission + logger.warning( + "Failed to lookup LLM credential info for tenant_id=%s, provider=%s, model=%s: %s", + tenant_id, + provider, + model, + str(e), + exc_info=True, + ) + return None, "" + + class OpsTraceProviderConfigMap(collections.UserDict[str, dict[str, Any]]): - def __getitem__(self, key: str) -> dict[str, Any]: - match key: + def __getitem__(self, provider: str) -> dict[str, Any]: + match provider: case TracingProviderEnum.LANGFUSE: from core.ops.entities.config_entity import LangfuseConfig from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace @@ -149,7 +294,7 @@ class OpsTraceProviderConfigMap(collections.UserDict[str, dict[str, Any]]): } case _: - raise KeyError(f"Unsupported tracing provider: {key}") + raise KeyError(f"Unsupported tracing provider: {provider}") provider_config_map = OpsTraceProviderConfigMap() @@ -275,10 +420,10 @@ class OpsTraceManager: :param tracing_provider: tracing provider :return: """ - trace_config_data: TraceAppConfig | None = ( - db.session.query(TraceAppConfig) + trace_config_data: TraceAppConfig | None = db.session.scalar( + select(TraceAppConfig) .where(TraceAppConfig.app_id == app_id, TraceAppConfig.tracing_provider == tracing_provider) - .first() + .limit(1) ) if not trace_config_data: @@ -314,7 +459,11 @@ class OpsTraceManager: if app_id is None: return None - app: App | None = db.session.query(App).where(App.id == app_id).first() + # Handle storage_id format (tenant-{uuid}) - not a real app_id + if isinstance(app_id, str) and app_id.startswith("tenant-"): + return None + + app = db.session.get(App, app_id) if app is None: return None @@ -388,7 +537,7 @@ class OpsTraceManager: except KeyError: raise ValueError(f"Invalid tracing provider: {tracing_provider}") - app_config: App | None = db.session.query(App).where(App.id == app_id).first() + app_config: App | None = db.session.get(App, app_id) if not app_config: raise 
ValueError("App not found") app_config.tracing = json.dumps( @@ -406,7 +555,7 @@ class OpsTraceManager: :param app_id: app id :return: """ - app: App | None = db.session.query(App).where(App.id == app_id).first() + app: App | None = db.session.get(App, app_id) if not app: raise ValueError("App not found") if not app.tracing: @@ -466,8 +615,6 @@ class TraceTask: @classmethod def _get_workflow_run_repo(cls): - from repositories.factory import DifyAPIRepositoryFactory - if cls._workflow_run_repo is None: with cls._repo_lock: if cls._workflow_run_repo is None: @@ -478,6 +625,77 @@ class TraceTask: cls._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker) return cls._workflow_run_repo + @classmethod + def _calculate_workflow_token_split( + cls, session: "Session", workflow_run_id: str, tenant_id: str + ) -> tuple[int, int]: + """Sum prompt/completion tokens across all node executions for a workflow run. + + Reads from the ``outputs`` column (where LLM nodes store ``usage.prompt_tokens`` + and ``usage.completion_tokens``) rather than ``execution_metadata``, which only + carries ``total_tokens``. Projects only the ``outputs`` column to avoid loading + large JSON blobs unnecessarily. 
+ """ + import json + + from models.workflow import WorkflowNodeExecutionModel + + rows = ( + session.execute( + select(WorkflowNodeExecutionModel.outputs).where( + WorkflowNodeExecutionModel.tenant_id == tenant_id, + WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id, + ) + ) + .scalars() + .all() + ) + + total_prompt = 0 + total_completion = 0 + + for raw in rows: + if not raw: + continue + try: + outputs = json.loads(raw) if isinstance(raw, str) else raw + except (ValueError, TypeError): + continue + if not isinstance(outputs, dict): + continue + usage = outputs.get("usage") + if not isinstance(usage, dict): + continue + prompt = usage.get("prompt_tokens") + if isinstance(prompt, (int, float)): + total_prompt += int(prompt) + completion = usage.get("completion_tokens") + if isinstance(completion, (int, float)): + total_completion += int(completion) + + return (total_prompt, total_completion) + + @classmethod + def _get_user_id_from_metadata(cls, metadata: dict[str, Any]) -> str: + """Extract user ID from metadata, prioritizing end_user over account. + + Returns the actual user ID (end_user or account) who invoked the workflow, + regardless of invoke_from context. 
+ """ + # Priority 1: End user (external users via API/WebApp) + if user_id := metadata.get("from_end_user_id"): + return f"end_user:{user_id}" + + # Priority 2: Account user (internal users via console/debugger) + if user_id := metadata.get("from_account_id"): + return f"account:{user_id}" + + # Priority 3: User (internal users via console/debugger) + if user_id := metadata.get("user_id"): + return f"user:{user_id}" + + return "anonymous" + def __init__( self, trace_type: Any, @@ -491,6 +709,7 @@ class TraceTask: self.trace_type = trace_type self.message_id = message_id self.workflow_run_id = workflow_execution.id_ if workflow_execution else None + self.workflow_total_tokens: int | None = workflow_execution.total_tokens if workflow_execution else None self.conversation_id = conversation_id self.user_id = user_id self.timer = timer @@ -498,6 +717,8 @@ class TraceTask: self.app_id = None self.trace_id = None self.kwargs = kwargs + if user_id is not None and "user_id" not in self.kwargs: + self.kwargs["user_id"] = user_id external_trace_id = kwargs.get("external_trace_id") if external_trace_id: self.trace_id = external_trace_id @@ -509,9 +730,12 @@ class TraceTask: preprocess_map = { TraceTaskName.CONVERSATION_TRACE: lambda: self.conversation_trace(**self.kwargs), TraceTaskName.WORKFLOW_TRACE: lambda: self.workflow_trace( - workflow_run_id=self.workflow_run_id, conversation_id=self.conversation_id, user_id=self.user_id + workflow_run_id=self.workflow_run_id, + conversation_id=self.conversation_id, + user_id=self.user_id, + total_tokens_override=self.workflow_total_tokens, ), - TraceTaskName.MESSAGE_TRACE: lambda: self.message_trace(message_id=self.message_id), + TraceTaskName.MESSAGE_TRACE: lambda: self.message_trace(message_id=self.message_id, **self.kwargs), TraceTaskName.MODERATION_TRACE: lambda: self.moderation_trace( message_id=self.message_id, timer=self.timer, **self.kwargs ), @@ -527,6 +751,9 @@ class TraceTask: TraceTaskName.GENERATE_NAME_TRACE: lambda: 
self.generate_name_trace( conversation_id=self.conversation_id, timer=self.timer, **self.kwargs ), + TraceTaskName.PROMPT_GENERATION_TRACE: lambda: self.prompt_generation_trace(**self.kwargs), + TraceTaskName.NODE_EXECUTION_TRACE: lambda: self.node_execution_trace(**self.kwargs), + TraceTaskName.DRAFT_NODE_EXECUTION_TRACE: lambda: self.draft_node_execution_trace(**self.kwargs), } return preprocess_map.get(self.trace_type, lambda: None)() @@ -541,6 +768,7 @@ class TraceTask: workflow_run_id: str | None, conversation_id: str | None, user_id: str | None, + total_tokens_override: int | None = None, ): if not workflow_run_id: return {} @@ -560,7 +788,7 @@ class TraceTask: workflow_run_version = workflow_run.version error = workflow_run.error or "" - total_tokens = workflow_run.total_tokens + total_tokens = total_tokens_override if total_tokens_override is not None else workflow_run.total_tokens file_list = workflow_run_inputs.get("sys.file") or [] query = workflow_run_inputs.get("query") or workflow_run_inputs.get("sys.query") or "" @@ -581,8 +809,18 @@ class TraceTask: Message.workflow_run_id == workflow_run_id, ) message_id = session.scalar(message_data_stmt) + prompt_tokens, completion_tokens = self._calculate_workflow_token_split( + session, workflow_run_id=workflow_run_id, tenant_id=tenant_id + ) - metadata = { + from core.telemetry.gateway import is_enterprise_telemetry_enabled + + if is_enterprise_telemetry_enabled(): + app_name, workspace_name = _lookup_app_and_workspace_names(workflow_run.app_id, tenant_id) + else: + app_name, workspace_name = "", "" + + metadata: dict[str, Any] = { "workflow_id": workflow_id, "conversation_id": conversation_id, "workflow_run_id": workflow_run_id, @@ -595,8 +833,14 @@ class TraceTask: "triggered_from": workflow_run.triggered_from, "user_id": user_id, "app_id": workflow_run.app_id, + "app_name": app_name, + "workspace_name": workspace_name, } + parent_trace_context = self.kwargs.get("parent_trace_context") + if 
parent_trace_context: + metadata["parent_trace_context"] = parent_trace_context + workflow_trace_info = WorkflowTraceInfo( trace_id=self.trace_id, workflow_data=workflow_run.to_dict(), @@ -611,6 +855,8 @@ class TraceTask: workflow_run_version=workflow_run_version, error=error, total_tokens=total_tokens, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, file_list=file_list, query=query, metadata=metadata, @@ -618,10 +864,11 @@ class TraceTask: message_id=message_id, start_time=workflow_run.created_at, end_time=workflow_run.finished_at, + invoked_by=self._get_user_id_from_metadata(metadata), ) return workflow_trace_info - def message_trace(self, message_id: str | None): + def message_trace(self, message_id: str | None, **kwargs): if not message_id: return {} message_data = get_message_data(message_id) @@ -636,7 +883,7 @@ class TraceTask: inputs = message_data.message # get message file data - message_file_data = db.session.query(MessageFile).filter_by(message_id=message_id).first() + message_file_data = db.session.scalar(select(MessageFile).where(MessageFile.message_id == message_id).limit(1)) file_list = [] if message_file_data and message_file_data.url is not None: file_url = f"{self.file_base_url}/{message_file_data.url}" if message_file_data else "" @@ -644,6 +891,19 @@ class TraceTask: streaming_metrics = self._extract_streaming_metrics(message_data) + tenant_id = "" + with Session(db.engine) as session: + tid = session.scalar(select(App.tenant_id).where(App.id == message_data.app_id)) + if tid: + tenant_id = str(tid) + + from core.telemetry.gateway import is_enterprise_telemetry_enabled + + if is_enterprise_telemetry_enabled(): + app_name, workspace_name = _lookup_app_and_workspace_names(message_data.app_id, tenant_id) + else: + app_name, workspace_name = "", "" + metadata = { "conversation_id": message_data.conversation_id, "ls_provider": message_data.model_provider, @@ -655,7 +915,14 @@ class TraceTask: "workflow_run_id": 
message_data.workflow_run_id, "from_source": message_data.from_source, "message_id": message_id, + "tenant_id": tenant_id, + "app_id": message_data.app_id, + "user_id": message_data.from_end_user_id or message_data.from_account_id, + "app_name": app_name, + "workspace_name": workspace_name, } + if node_execution_id := kwargs.get("node_execution_id"): + metadata["node_execution_id"] = node_execution_id message_tokens = message_data.message_tokens @@ -672,7 +939,9 @@ class TraceTask: outputs=message_data.answer, file_list=file_list, start_time=created_at, - end_time=created_at + timedelta(seconds=message_data.provider_response_latency), + end_time=message_data.updated_at + if message_data.updated_at and message_data.updated_at > created_at + else created_at + timedelta(seconds=message_data.provider_response_latency), metadata=metadata, message_file_data=message_file_data, conversation_mode=conversation_mode, @@ -697,12 +966,14 @@ class TraceTask: "preset_response": moderation_result.preset_response, "query": moderation_result.query, } + if node_execution_id := kwargs.get("node_execution_id"): + metadata["node_execution_id"] = node_execution_id # get workflow_app_log_id workflow_app_log_id = None if message_data.workflow_run_id: - workflow_app_log_data = ( - db.session.query(WorkflowAppLog).filter_by(workflow_run_id=message_data.workflow_run_id).first() + workflow_app_log_data = db.session.scalar( + select(WorkflowAppLog).where(WorkflowAppLog.workflow_run_id == message_data.workflow_run_id).limit(1) ) workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None @@ -738,12 +1009,14 @@ class TraceTask: "workflow_run_id": message_data.workflow_run_id, "from_source": message_data.from_source, } + if node_execution_id := kwargs.get("node_execution_id"): + metadata["node_execution_id"] = node_execution_id # get workflow_app_log_id workflow_app_log_id = None if message_data.workflow_run_id: - workflow_app_log_data = ( - 
db.session.query(WorkflowAppLog).filter_by(workflow_run_id=message_data.workflow_run_id).first() + workflow_app_log_data = db.session.scalar( + select(WorkflowAppLog).where(WorkflowAppLog.workflow_run_id == message_data.workflow_run_id).limit(1) ) workflow_app_log_id = str(workflow_app_log_data.id) if workflow_app_log_data else None @@ -777,6 +1050,52 @@ class TraceTask: if not message_data: return {} + tenant_id = "" + with Session(db.engine) as session: + tid = session.scalar(select(App.tenant_id).where(App.id == message_data.app_id)) + if tid: + tenant_id = str(tid) + + from core.telemetry.gateway import is_enterprise_telemetry_enabled + + if is_enterprise_telemetry_enabled(): + app_name, workspace_name = _lookup_app_and_workspace_names(message_data.app_id, tenant_id) + else: + app_name, workspace_name = "", "" + + doc_list = [doc.model_dump() for doc in documents] if documents else [] + dataset_ids: set[str] = set() + for doc in doc_list: + doc_meta = doc.get("metadata") or {} + did = doc_meta.get("dataset_id") + if did: + dataset_ids.add(did) + + embedding_models: dict[str, dict[str, str]] = {} + if dataset_ids: + with Session(db.engine) as session: + rows = session.execute( + select(Dataset.id, Dataset.embedding_model, Dataset.embedding_model_provider).where( + Dataset.id.in_(list(dataset_ids)) + ) + ).all() + for row in rows: + embedding_models[str(row[0])] = { + "embedding_model": row[1] or "", + "embedding_model_provider": row[2] or "", + } + + # Extract rerank model info from retrieval_model kwargs + rerank_model_provider = "" + rerank_model_name = "" + if "retrieval_model" in kwargs: + retrieval_model = kwargs["retrieval_model"] + if isinstance(retrieval_model, dict): + reranking_model = retrieval_model.get("reranking_model") + if isinstance(reranking_model, dict): + rerank_model_provider = reranking_model.get("reranking_provider_name", "") + rerank_model_name = reranking_model.get("reranking_model_name", "") + metadata = { "message_id": message_id, 
"ls_provider": message_data.model_provider, @@ -787,13 +1106,23 @@ class TraceTask: "agent_based": message_data.agent_based, "workflow_run_id": message_data.workflow_run_id, "from_source": message_data.from_source, + "tenant_id": tenant_id, + "app_id": message_data.app_id, + "user_id": message_data.from_end_user_id or message_data.from_account_id, + "app_name": app_name, + "workspace_name": workspace_name, + "embedding_models": embedding_models, + "rerank_model_provider": rerank_model_provider, + "rerank_model_name": rerank_model_name, } + if node_execution_id := kwargs.get("node_execution_id"): + metadata["node_execution_id"] = node_execution_id dataset_retrieval_trace_info = DatasetRetrievalTraceInfo( trace_id=self.trace_id, message_id=message_id, inputs=message_data.query or message_data.inputs, - documents=[doc.model_dump() for doc in documents] if documents else [], + documents=doc_list, start_time=timer.get("start"), end_time=timer.get("end"), metadata=metadata, @@ -836,9 +1165,13 @@ class TraceTask: "error": error, "tool_parameters": tool_parameters, } + if message_data.workflow_run_id: + metadata["workflow_run_id"] = message_data.workflow_run_id + if node_execution_id := kwargs.get("node_execution_id"): + metadata["node_execution_id"] = node_execution_id file_url = "" - message_file_data = db.session.query(MessageFile).filter_by(message_id=message_id).first() + message_file_data = db.session.scalar(select(MessageFile).where(MessageFile.message_id == message_id).limit(1)) if message_file_data: message_file_id = message_file_data.id if message_file_data else None type = message_file_data.type @@ -890,6 +1223,8 @@ class TraceTask: "conversation_id": conversation_id, "tenant_id": tenant_id, } + if node_execution_id := kwargs.get("node_execution_id"): + metadata["node_execution_id"] = node_execution_id generate_name_trace_info = GenerateNameTraceInfo( trace_id=self.trace_id, @@ -904,6 +1239,182 @@ class TraceTask: return generate_name_trace_info + def 
prompt_generation_trace(self, **kwargs) -> PromptGenerationTraceInfo | dict: + tenant_id = kwargs.get("tenant_id", "") + user_id = kwargs.get("user_id", "") + app_id = kwargs.get("app_id") + operation_type = kwargs.get("operation_type", "") + instruction = kwargs.get("instruction", "") + generated_output = kwargs.get("generated_output", "") + + prompt_tokens = kwargs.get("prompt_tokens", 0) + completion_tokens = kwargs.get("completion_tokens", 0) + total_tokens = kwargs.get("total_tokens", 0) + + model_provider = kwargs.get("model_provider", "") + model_name = kwargs.get("model_name", "") + + latency = kwargs.get("latency", 0.0) + + timer = kwargs.get("timer") + start_time = timer.get("start") if timer else None + end_time = timer.get("end") if timer else None + + total_price = kwargs.get("total_price") + currency = kwargs.get("currency") + + error = kwargs.get("error") + + app_name = None + workspace_name = None + if app_id: + app_name, workspace_name = _lookup_app_and_workspace_names(app_id, tenant_id) + + metadata = { + "tenant_id": tenant_id, + "user_id": user_id, + "app_id": app_id or "", + "app_name": app_name, + "workspace_name": workspace_name, + "operation_type": operation_type, + "model_provider": model_provider, + "model_name": model_name, + } + if node_execution_id := kwargs.get("node_execution_id"): + metadata["node_execution_id"] = node_execution_id + + return PromptGenerationTraceInfo( + trace_id=self.trace_id, + inputs=instruction, + outputs=generated_output, + start_time=start_time, + end_time=end_time, + metadata=metadata, + tenant_id=tenant_id, + user_id=user_id, + app_id=app_id, + operation_type=operation_type, + instruction=instruction, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + total_tokens=total_tokens, + model_provider=model_provider, + model_name=model_name, + latency=latency, + total_price=total_price, + currency=currency, + error=error, + ) + + def node_execution_trace(self, **kwargs) -> WorkflowNodeTraceInfo 
| dict: + node_data: dict = kwargs.get("node_execution_data", {}) + if not node_data: + return {} + + from core.telemetry.gateway import is_enterprise_telemetry_enabled + + if is_enterprise_telemetry_enabled(): + app_name, workspace_name = _lookup_app_and_workspace_names( + node_data.get("app_id"), node_data.get("tenant_id") + ) + else: + app_name, workspace_name = "", "" + + # Try tool credential lookup first + credential_id = node_data.get("credential_id") + if is_enterprise_telemetry_enabled(): + credential_name = _lookup_credential_name(credential_id, node_data.get("credential_provider_type")) + # If no credential_id found (e.g., LLM nodes), try LLM credential lookup + if not credential_id: + llm_cred_id, llm_cred_name = _lookup_llm_credential_info( + tenant_id=node_data.get("tenant_id"), + provider=node_data.get("model_provider"), + model=node_data.get("model_name"), + model_type="llm", + ) + if llm_cred_id: + credential_id = llm_cred_id + credential_name = llm_cred_name + else: + credential_name = "" + metadata: dict[str, Any] = { + "tenant_id": node_data.get("tenant_id"), + "app_id": node_data.get("app_id"), + "app_name": app_name, + "workspace_name": workspace_name, + "user_id": node_data.get("user_id"), + "invoke_from": node_data.get("invoke_from"), + "credential_id": credential_id, + "credential_name": credential_name, + "dataset_ids": node_data.get("dataset_ids"), + "dataset_names": node_data.get("dataset_names"), + "plugin_name": node_data.get("plugin_name"), + } + + parent_trace_context = node_data.get("parent_trace_context") + if parent_trace_context: + metadata["parent_trace_context"] = parent_trace_context + + message_id: str | None = None + conversation_id = node_data.get("conversation_id") + workflow_execution_id = node_data.get("workflow_execution_id") + if conversation_id and workflow_execution_id and not parent_trace_context: + with Session(db.engine) as session: + msg_id = session.scalar( + select(Message.id).where( + Message.conversation_id 
== conversation_id, + Message.workflow_run_id == workflow_execution_id, + ) + ) + if msg_id: + message_id = str(msg_id) + metadata["message_id"] = message_id + if conversation_id: + metadata["conversation_id"] = conversation_id + + return WorkflowNodeTraceInfo( + trace_id=self.trace_id, + message_id=message_id, + start_time=node_data.get("created_at"), + end_time=node_data.get("finished_at"), + metadata=metadata, + workflow_id=node_data.get("workflow_id", ""), + workflow_run_id=node_data.get("workflow_execution_id", ""), + tenant_id=node_data.get("tenant_id", ""), + node_execution_id=node_data.get("node_execution_id", ""), + node_id=node_data.get("node_id", ""), + node_type=node_data.get("node_type", ""), + title=node_data.get("title", ""), + status=node_data.get("status", ""), + error=node_data.get("error"), + elapsed_time=node_data.get("elapsed_time", 0.0), + index=node_data.get("index", 0), + predecessor_node_id=node_data.get("predecessor_node_id"), + total_tokens=node_data.get("total_tokens", 0), + total_price=node_data.get("total_price", 0.0), + currency=node_data.get("currency"), + model_provider=node_data.get("model_provider"), + model_name=node_data.get("model_name"), + prompt_tokens=node_data.get("prompt_tokens"), + completion_tokens=node_data.get("completion_tokens"), + tool_name=node_data.get("tool_name"), + iteration_id=node_data.get("iteration_id"), + iteration_index=node_data.get("iteration_index"), + loop_id=node_data.get("loop_id"), + loop_index=node_data.get("loop_index"), + parallel_id=node_data.get("parallel_id"), + node_inputs=node_data.get("node_inputs"), + node_outputs=node_data.get("node_outputs"), + process_data=node_data.get("process_data"), + invoked_by=self._get_user_id_from_metadata(metadata), + ) + + def draft_node_execution_trace(self, **kwargs) -> DraftNodeExecutionTrace | dict: + node_trace = self.node_execution_trace(**kwargs) + if not isinstance(node_trace, WorkflowNodeTraceInfo): + return node_trace + return 
DraftNodeExecutionTrace(**node_trace.model_dump()) + def _extract_streaming_metrics(self, message_data) -> dict: if not message_data.message_metadata: return {} @@ -937,13 +1448,17 @@ class TraceQueueManager: self.user_id = user_id self.trace_instance = OpsTraceManager.get_ops_trace_instance(app_id) self.flask_app = current_app._get_current_object() # type: ignore + + from core.telemetry.gateway import is_enterprise_telemetry_enabled + + self._enterprise_telemetry_enabled = is_enterprise_telemetry_enabled() if trace_manager_timer is None: self.start_timer() def add_trace_task(self, trace_task: TraceTask): global trace_manager_timer, trace_manager_queue try: - if self.trace_instance: + if self._enterprise_telemetry_enabled or self.trace_instance: trace_task.app_id = self.app_id trace_manager_queue.put(trace_task) except Exception: @@ -979,20 +1494,27 @@ class TraceQueueManager: def send_to_celery(self, tasks: list[TraceTask]): with self.flask_app.app_context(): for task in tasks: - if task.app_id is None: - continue + storage_id = task.app_id + if storage_id is None: + tenant_id = task.kwargs.get("tenant_id") + if tenant_id: + storage_id = f"tenant-{tenant_id}" + else: + logger.warning("Skipping trace without app_id or tenant_id, trace_type: %s", task.trace_type) + continue + file_id = uuid4().hex trace_info = task.execute() task_data = TaskData( - app_id=task.app_id, + app_id=storage_id, trace_info_type=type(trace_info).__name__, trace_info=trace_info.model_dump() if trace_info else None, ) - file_path = f"{OPS_FILE_PATH}{task.app_id}/{file_id}.json" + file_path = f"{OPS_FILE_PATH}{storage_id}/{file_id}.json" storage.save(file_path, task_data.model_dump_json().encode("utf-8")) file_info = { "file_id": file_id, - "app_id": task.app_id, + "app_id": storage_id, } process_trace_tasks.delay(file_info) # type: ignore diff --git a/api/core/ops/tencent_trace/span_builder.py b/api/core/ops/tencent_trace/span_builder.py index 0a6013e244..f79095d966 100644 --- 
a/api/core/ops/tencent_trace/span_builder.py +++ b/api/core/ops/tencent_trace/span_builder.py @@ -6,6 +6,8 @@ import json import logging from datetime import datetime +from graphon.entities import WorkflowNodeExecution +from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from opentelemetry.trace import Status, StatusCode from core.ops.entities.trace_entity import ( @@ -41,11 +43,6 @@ from core.ops.tencent_trace.entities.semconv import ( from core.ops.tencent_trace.entities.tencent_trace_entity import SpanData from core.ops.tencent_trace.utils import TencentTraceUtils from core.rag.models.document import Document -from dify_graph.entities.workflow_node_execution import ( - WorkflowNodeExecution, - WorkflowNodeExecutionMetadataKey, - WorkflowNodeExecutionStatus, -) logger = logging.getLogger(__name__) diff --git a/api/core/ops/tencent_trace/tencent_trace.py b/api/core/ops/tencent_trace/tencent_trace.py index 7e56b1effa..2bd6db22bf 100644 --- a/api/core/ops/tencent_trace/tencent_trace.py +++ b/api/core/ops/tencent_trace/tencent_trace.py @@ -4,6 +4,10 @@ Tencent APM tracing implementation with separated concerns import logging +from graphon.entities.workflow_node_execution import ( + WorkflowNodeExecution, +) +from graphon.nodes import BuiltinNodeTypes from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker @@ -24,10 +28,6 @@ from core.ops.tencent_trace.entities.tencent_trace_entity import SpanData from core.ops.tencent_trace.span_builder import TencentSpanBuilder from core.ops.tencent_trace.utils import TencentTraceUtils from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository -from dify_graph.entities.workflow_node_execution import ( - WorkflowNodeExecution, -) -from dify_graph.nodes import BuiltinNodeTypes from extensions.ext_database import db from models import Account, App, TenantAccountJoin, WorkflowNodeExecutionTriggeredFrom @@ -256,7 +256,7 @@ class TencentDataTrace(BaseTraceInstance): 
triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, ) - executions = repository.get_by_workflow_run(workflow_run_id=trace_info.workflow_run_id) + executions = repository.get_by_workflow_execution(workflow_execution_id=trace_info.workflow_run_id) return list(executions) except Exception: diff --git a/api/core/ops/weave_trace/weave_trace.py b/api/core/ops/weave_trace/weave_trace.py index 2a657b672c..8d9ba4694d 100644 --- a/api/core/ops/weave_trace/weave_trace.py +++ b/api/core/ops/weave_trace/weave_trace.py @@ -6,6 +6,7 @@ from typing import Any, cast import wandb import weave +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from sqlalchemy.orm import sessionmaker from weave.trace_server.trace_server_interface import ( CallEndReq, @@ -31,7 +32,6 @@ from core.ops.entities.trace_entity import ( ) from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel from core.repositories import DifyCoreRepositoryFactory -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from extensions.ext_database import db from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom @@ -161,8 +161,8 @@ class WeaveDataTrace(BaseTraceInstance): ) # Get all executions for this workflow run - workflow_node_executions = workflow_node_execution_repository.get_by_workflow_run( - workflow_run_id=trace_info.workflow_run_id + workflow_node_executions = workflow_node_execution_repository.get_by_workflow_execution( + workflow_execution_id=trace_info.workflow_run_id ) # rearrange workflow_node_executions by starting time @@ -245,9 +245,7 @@ class WeaveDataTrace(BaseTraceInstance): attributes["user_id"] = user_id if message_data.from_end_user_id: - end_user_data: EndUser | None = ( - db.session.query(EndUser).where(EndUser.id == message_data.from_end_user_id).first() - ) + end_user_data: EndUser | None = db.session.get(EndUser, message_data.from_end_user_id) if end_user_data is not None: end_user_id = 
end_user_data.session_id attributes["end_user_id"] = end_user_id diff --git a/api/core/plugin/backwards_invocation/app.py b/api/core/plugin/backwards_invocation/app.py index 60d08b26c9..be11d2223c 100644 --- a/api/core/plugin/backwards_invocation/app.py +++ b/api/core/plugin/backwards_invocation/app.py @@ -227,7 +227,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation): get app """ try: - app = db.session.query(App).where(App.id == app_id).where(App.tenant_id == tenant_id).first() + app = db.session.scalar(select(App).where(App.id == app_id, App.tenant_id == tenant_id).limit(1)) except Exception: raise ValueError("app not found") diff --git a/api/core/plugin/backwards_invocation/model.py b/api/core/plugin/backwards_invocation/model.py index 11c9191bac..c715b9171c 100644 --- a/api/core/plugin/backwards_invocation/model.py +++ b/api/core/plugin/backwards_invocation/model.py @@ -2,6 +2,20 @@ import tempfile from binascii import hexlify, unhexlify from collections.abc import Generator +from graphon.model_runtime.entities.llm_entities import ( + LLMResult, + LLMResultChunk, + LLMResultChunkDelta, + LLMResultChunkWithStructuredOutput, + LLMResultWithStructuredOutput, +) +from graphon.model_runtime.entities.message_entities import ( + PromptMessage, + SystemPromptMessage, + UserPromptMessage, +) +from graphon.model_runtime.entities.model_entities import ModelType + from core.app.llm import deduct_llm_quota from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output from core.model_manager import ModelManager @@ -18,22 +32,26 @@ from core.plugin.entities.request import ( ) from core.tools.entities.tool_entities import ToolProviderType from core.tools.utils.model_invocation_utils import ModelInvocationUtils -from dify_graph.model_runtime.entities.llm_entities import ( - LLMResult, - LLMResultChunk, - LLMResultChunkDelta, - LLMResultChunkWithStructuredOutput, - LLMResultWithStructuredOutput, -) -from 
dify_graph.model_runtime.entities.message_entities import ( - PromptMessage, - SystemPromptMessage, - UserPromptMessage, -) from models.account import Tenant class PluginModelBackwardsInvocation(BaseBackwardsInvocation): + @staticmethod + def _get_bound_model_instance( + *, + tenant_id: str, + user_id: str | None, + provider: str, + model_type: ModelType, + model: str, + ): + return ModelManager.for_tenant(tenant_id=tenant_id, user_id=user_id).get_model_instance( + tenant_id=tenant_id, + provider=provider, + model_type=model_type, + model=model, + ) + @classmethod def invoke_llm( cls, user_id: str, tenant: Tenant, payload: RequestInvokeLLM @@ -41,8 +59,9 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): """ invoke llm """ - model_instance = ModelManager().get_model_instance( + model_instance = cls._get_bound_model_instance( tenant_id=tenant.id, + user_id=user_id, provider=payload.provider, model_type=payload.model_type, model=payload.model, @@ -55,7 +74,6 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): tools=payload.tools, stop=payload.stop, stream=True if payload.stream is None else payload.stream, - user=user_id, ) if isinstance(response, Generator): @@ -94,8 +112,9 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): """ invoke llm with structured output """ - model_instance = ModelManager().get_model_instance( + model_instance = cls._get_bound_model_instance( tenant_id=tenant.id, + user_id=user_id, provider=payload.provider, model_type=payload.model_type, model=payload.model, @@ -115,7 +134,6 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): tools=payload.tools, stop=payload.stop, stream=True if payload.stream is None else payload.stream, - user=user_id, model_parameters=payload.completion_params, ) @@ -156,18 +174,16 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): """ invoke text embedding """ - model_instance = ModelManager().get_model_instance( + model_instance = 
cls._get_bound_model_instance( tenant_id=tenant.id, + user_id=user_id, provider=payload.provider, model_type=payload.model_type, model=payload.model, ) # invoke model - response = model_instance.invoke_text_embedding( - texts=payload.texts, - user=user_id, - ) + response = model_instance.invoke_text_embedding(texts=payload.texts) return response @@ -176,8 +192,9 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): """ invoke rerank """ - model_instance = ModelManager().get_model_instance( + model_instance = cls._get_bound_model_instance( tenant_id=tenant.id, + user_id=user_id, provider=payload.provider, model_type=payload.model_type, model=payload.model, @@ -189,7 +206,6 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): docs=payload.docs, score_threshold=payload.score_threshold, top_n=payload.top_n, - user=user_id, ) return response @@ -199,20 +215,16 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): """ invoke tts """ - model_instance = ModelManager().get_model_instance( + model_instance = cls._get_bound_model_instance( tenant_id=tenant.id, + user_id=user_id, provider=payload.provider, model_type=payload.model_type, model=payload.model, ) # invoke model - response = model_instance.invoke_tts( - content_text=payload.content_text, - tenant_id=tenant.id, - voice=payload.voice, - user=user_id, - ) + response = model_instance.invoke_tts(content_text=payload.content_text, voice=payload.voice) def handle() -> Generator[dict, None, None]: for chunk in response: @@ -225,8 +237,9 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): """ invoke speech2text """ - model_instance = ModelManager().get_model_instance( + model_instance = cls._get_bound_model_instance( tenant_id=tenant.id, + user_id=user_id, provider=payload.provider, model_type=payload.model_type, model=payload.model, @@ -238,10 +251,7 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): temp.flush() temp.seek(0) - response = 
model_instance.invoke_speech2text( - file=temp, - user=user_id, - ) + response = model_instance.invoke_speech2text(file=temp) return { "result": response, @@ -252,36 +262,38 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): """ invoke moderation """ - model_instance = ModelManager().get_model_instance( + model_instance = cls._get_bound_model_instance( tenant_id=tenant.id, + user_id=user_id, provider=payload.provider, model_type=payload.model_type, model=payload.model, ) # invoke model - response = model_instance.invoke_moderation( - text=payload.text, - user=user_id, - ) + response = model_instance.invoke_moderation(text=payload.text) return { "result": response, } @classmethod - def get_system_model_max_tokens(cls, tenant_id: str) -> int: + def get_system_model_max_tokens(cls, tenant_id: str, user_id: str | None = None) -> int: """ get system model max tokens """ - return ModelInvocationUtils.get_max_llm_context_tokens(tenant_id=tenant_id) + return ModelInvocationUtils.get_max_llm_context_tokens(tenant_id=tenant_id, user_id=user_id) @classmethod - def get_prompt_tokens(cls, tenant_id: str, prompt_messages: list[PromptMessage]) -> int: + def get_prompt_tokens(cls, tenant_id: str, prompt_messages: list[PromptMessage], user_id: str | None = None) -> int: """ get prompt tokens """ - return ModelInvocationUtils.calculate_tokens(tenant_id=tenant_id, prompt_messages=prompt_messages) + return ModelInvocationUtils.calculate_tokens( + tenant_id=tenant_id, + prompt_messages=prompt_messages, + user_id=user_id, + ) @classmethod def invoke_system_model( @@ -299,6 +311,7 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): tool_type=ToolProviderType.PLUGIN, tool_name="plugin", prompt_messages=prompt_messages, + caller_user_id=user_id, ) @classmethod @@ -306,7 +319,7 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation): """ invoke summary """ - max_tokens = cls.get_system_model_max_tokens(tenant_id=tenant.id) + max_tokens = 
cls.get_system_model_max_tokens(tenant_id=tenant.id, user_id=user_id) content = payload.text SUMMARY_PROMPT = """You are a professional language researcher, you are interested in the language @@ -325,6 +338,7 @@ Here is the extra instruction you need to follow: cls.get_prompt_tokens( tenant_id=tenant.id, prompt_messages=[UserPromptMessage(content=content)], + user_id=user_id, ) < max_tokens * 0.6 ): @@ -337,6 +351,7 @@ Here is the extra instruction you need to follow: SystemPromptMessage(content=SUMMARY_PROMPT.replace("{payload.instruction}", payload.instruction)), UserPromptMessage(content=content), ], + user_id=user_id, ) def summarize(content: str) -> str: @@ -394,6 +409,7 @@ Here is the extra instruction you need to follow: cls.get_prompt_tokens( tenant_id=tenant.id, prompt_messages=[UserPromptMessage(content=result)], + user_id=user_id, ) > max_tokens * 0.7 ): diff --git a/api/core/plugin/backwards_invocation/node.py b/api/core/plugin/backwards_invocation/node.py index d6aef93fc4..9478997494 100644 --- a/api/core/plugin/backwards_invocation/node.py +++ b/api/core/plugin/backwards_invocation/node.py @@ -1,19 +1,15 @@ -from core.plugin.backwards_invocation.base import BaseBackwardsInvocation -from dify_graph.enums import BuiltinNodeTypes -from dify_graph.nodes.parameter_extractor.entities import ( - ModelConfig as ParameterExtractorModelConfig, -) -from dify_graph.nodes.parameter_extractor.entities import ( +from graphon.enums import BuiltinNodeTypes +from graphon.nodes.llm.entities import ModelConfig as LLMModelConfig +from graphon.nodes.parameter_extractor.entities import ( ParameterConfig, ParameterExtractorNodeData, ) -from dify_graph.nodes.question_classifier.entities import ( +from graphon.nodes.question_classifier.entities import ( ClassConfig, QuestionClassifierNodeData, ) -from dify_graph.nodes.question_classifier.entities import ( - ModelConfig as QuestionClassifierModelConfig, -) + +from core.plugin.backwards_invocation.base import 
BaseBackwardsInvocation from services.workflow_service import WorkflowService @@ -24,7 +20,7 @@ class PluginNodeBackwardsInvocation(BaseBackwardsInvocation): tenant_id: str, user_id: str, parameters: list[ParameterConfig], - model_config: ParameterExtractorModelConfig, + model_config: LLMModelConfig, instruction: str, query: str, ): @@ -74,7 +70,7 @@ class PluginNodeBackwardsInvocation(BaseBackwardsInvocation): cls, tenant_id: str, user_id: str, - model_config: QuestionClassifierModelConfig, + model_config: LLMModelConfig, classes: list[ClassConfig], instruction: str, query: str, diff --git a/api/core/plugin/backwards_invocation/tool.py b/api/core/plugin/backwards_invocation/tool.py index c2d1574e67..0585494269 100644 --- a/api/core/plugin/backwards_invocation/tool.py +++ b/api/core/plugin/backwards_invocation/tool.py @@ -31,7 +31,13 @@ class PluginToolBackwardsInvocation(BaseBackwardsInvocation): # get tool runtime try: tool_runtime = ToolManager.get_tool_runtime_from_plugin( - tool_type, tenant_id, provider, tool_name, tool_parameters, credential_id + tool_type, + tenant_id, + provider, + tool_name, + tool_parameters, + user_id=user_id, + credential_id=credential_id, ) response = ToolEngine.generic_invoke( tool_runtime, tool_parameters, user_id, DifyWorkflowCallbackHandler(), workflow_call_depth=1 diff --git a/api/core/plugin/entities/marketplace.py b/api/core/plugin/entities/marketplace.py index 81e1e12c5f..2177e8af90 100644 --- a/api/core/plugin/entities/marketplace.py +++ b/api/core/plugin/entities/marketplace.py @@ -1,10 +1,10 @@ +from graphon.model_runtime.entities.provider_entities import ProviderEntity from pydantic import BaseModel, Field, computed_field, model_validator from core.plugin.entities.endpoint import EndpointProviderDeclaration from core.plugin.entities.plugin import PluginResourceRequirements from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolProviderEntity -from 
dify_graph.model_runtime.entities.provider_entities import ProviderEntity class MarketplacePluginDeclaration(BaseModel): diff --git a/api/core/plugin/entities/plugin.py b/api/core/plugin/entities/plugin.py index 7a3780f7de..b095b4998d 100644 --- a/api/core/plugin/entities/plugin.py +++ b/api/core/plugin/entities/plugin.py @@ -3,6 +3,7 @@ from collections.abc import Mapping from enum import StrEnum, auto from typing import Any +from graphon.model_runtime.entities.provider_entities import ProviderEntity from packaging.version import InvalidVersion, Version from pydantic import BaseModel, Field, field_validator, model_validator @@ -13,7 +14,6 @@ from core.plugin.entities.endpoint import EndpointProviderDeclaration from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolProviderEntity from core.trigger.entities.entities import TriggerProviderEntity -from dify_graph.model_runtime.entities.provider_entities import ProviderEntity class PluginInstallationSource(StrEnum): diff --git a/api/core/plugin/entities/plugin_daemon.py b/api/core/plugin/entities/plugin_daemon.py index 416e0f6b4d..94263ec44e 100644 --- a/api/core/plugin/entities/plugin_daemon.py +++ b/api/core/plugin/entities/plugin_daemon.py @@ -6,6 +6,8 @@ from datetime import datetime from enum import StrEnum from typing import Any, Generic, TypeVar +from graphon.model_runtime.entities.model_entities import AIModelEntity +from graphon.model_runtime.entities.provider_entities import ProviderEntity from pydantic import BaseModel, ConfigDict, Field from core.agent.plugin_entities import AgentProviderEntityWithPlugin @@ -16,8 +18,6 @@ from core.plugin.entities.plugin import PluginDeclaration, PluginEntity from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolProviderEntityWithPlugin from core.trigger.entities.entities import TriggerProviderEntity -from dify_graph.model_runtime.entities.model_entities import 
AIModelEntity -from dify_graph.model_runtime.entities.provider_entities import ProviderEntity T = TypeVar("T", bound=(BaseModel | dict | list | bool | str)) diff --git a/api/core/plugin/entities/request.py b/api/core/plugin/entities/request.py index c15e9b0385..059f3fa9be 100644 --- a/api/core/plugin/entities/request.py +++ b/api/core/plugin/entities/request.py @@ -4,11 +4,7 @@ from collections.abc import Mapping from typing import Any, Literal from flask import Response -from pydantic import BaseModel, ConfigDict, Field, field_validator - -from core.entities.provider_entities import BasicProviderConfig -from core.plugin.utils.http_parser import deserialize_response -from dify_graph.model_runtime.entities.message_entities import ( +from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, PromptMessage, PromptMessageRole, @@ -17,19 +13,18 @@ from dify_graph.model_runtime.entities.message_entities import ( ToolPromptMessage, UserPromptMessage, ) -from dify_graph.model_runtime.entities.model_entities import ModelType -from dify_graph.nodes.parameter_extractor.entities import ( - ModelConfig as ParameterExtractorModelConfig, -) -from dify_graph.nodes.parameter_extractor.entities import ( +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.nodes.llm.entities import ModelConfig as LLMModelConfig +from graphon.nodes.parameter_extractor.entities import ( ParameterConfig, ) -from dify_graph.nodes.question_classifier.entities import ( +from graphon.nodes.question_classifier.entities import ( ClassConfig, ) -from dify_graph.nodes.question_classifier.entities import ( - ModelConfig as QuestionClassifierModelConfig, -) +from pydantic import BaseModel, ConfigDict, Field, field_validator + +from core.entities.provider_entities import BasicProviderConfig +from core.plugin.utils.http_parser import deserialize_response class InvokeCredentials(BaseModel): @@ -176,7 +171,7 @@ class 
RequestInvokeParameterExtractorNode(BaseModel): """ parameters: list[ParameterConfig] - model: ParameterExtractorModelConfig + model: LLMModelConfig instruction: str query: str @@ -187,7 +182,7 @@ class RequestInvokeQuestionClassifierNode(BaseModel): """ query: str - model: QuestionClassifierModelConfig + model: LLMModelConfig classes: list[ClassConfig] instruction: str diff --git a/api/core/plugin/impl/base.py b/api/core/plugin/impl/base.py index 737d204105..2d0ab3fcd7 100644 --- a/api/core/plugin/impl/base.py +++ b/api/core/plugin/impl/base.py @@ -5,6 +5,14 @@ from collections.abc import Callable, Generator from typing import Any, TypeVar, cast import httpx +from graphon.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) +from graphon.model_runtime.errors.validate import CredentialsValidateFailedError from pydantic import BaseModel from yarl import URL @@ -13,6 +21,7 @@ from core.plugin.endpoint.exc import EndpointSetupFailedError from core.plugin.entities.plugin_daemon import PluginDaemonBasicResponse, PluginDaemonError, PluginDaemonInnerError from core.plugin.impl.exc import ( PluginDaemonBadRequestError, + PluginDaemonClientSideError, PluginDaemonInternalServerError, PluginDaemonNotFoundError, PluginDaemonUnauthorizedError, @@ -27,14 +36,6 @@ from core.trigger.errors import ( TriggerPluginInvokeError, TriggerProviderCredentialValidationError, ) -from dify_graph.model_runtime.errors.invoke import ( - InvokeAuthorizationError, - InvokeBadRequestError, - InvokeConnectionError, - InvokeRateLimitError, - InvokeServerUnavailableError, -) -from dify_graph.model_runtime.errors.validate import CredentialsValidateFailedError plugin_daemon_inner_api_baseurl = URL(str(dify_config.PLUGIN_DAEMON_URL)) _plugin_daemon_timeout_config = cast( @@ -235,7 +236,10 @@ class BasePluginClient: response.raise_for_status() except httpx.HTTPStatusError as e: 
logger.exception("Failed to request plugin daemon, status: %s, url: %s", e.response.status_code, path) - raise e + if e.response.status_code < 500: + raise PluginDaemonClientSideError(description=str(e)) + else: + raise PluginDaemonInternalServerError(description=str(e)) except Exception as e: msg = f"Failed to request plugin daemon, url: {path}" logger.exception("Failed to request plugin daemon, url: %s", path) diff --git a/api/core/plugin/impl/model.py b/api/core/plugin/impl/model.py index 49ee5d79cb..1e38c24717 100644 --- a/api/core/plugin/impl/model.py +++ b/api/core/plugin/impl/model.py @@ -1,6 +1,13 @@ import binascii from collections.abc import Generator, Sequence -from typing import IO +from typing import IO, Any + +from graphon.model_runtime.entities.llm_entities import LLMResultChunk +from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool +from graphon.model_runtime.entities.model_entities import AIModelEntity +from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult +from graphon.model_runtime.entities.text_embedding_entities import EmbeddingResult +from graphon.model_runtime.utils.encoders import jsonable_encoder from core.plugin.entities.plugin_daemon import ( PluginBasicBooleanResponse, @@ -13,15 +20,16 @@ from core.plugin.entities.plugin_daemon import ( PluginVoicesResponse, ) from core.plugin.impl.base import BasePluginClient -from dify_graph.model_runtime.entities.llm_entities import LLMResultChunk -from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool -from dify_graph.model_runtime.entities.model_entities import AIModelEntity -from dify_graph.model_runtime.entities.rerank_entities import RerankResult -from dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingResult -from dify_graph.model_runtime.utils.encoders import jsonable_encoder class PluginModelClient(BasePluginClient): + @staticmethod + def 
_dispatch_payload(*, user_id: str | None, data: dict[str, Any]) -> dict[str, Any]: + payload: dict[str, Any] = {"data": data} + if user_id is not None: + payload["user_id"] = user_id + return payload + def fetch_model_providers(self, tenant_id: str) -> Sequence[PluginModelProviderEntity]: """ Fetch model providers for the given tenant. @@ -37,7 +45,7 @@ class PluginModelClient(BasePluginClient): def get_model_schema( self, tenant_id: str, - user_id: str, + user_id: str | None, plugin_id: str, provider: str, model_type: str, @@ -51,15 +59,15 @@ class PluginModelClient(BasePluginClient): "POST", f"plugin/{tenant_id}/dispatch/model/schema", PluginModelSchemaEntity, - data={ - "user_id": user_id, - "data": { + data=self._dispatch_payload( + user_id=user_id, + data={ "provider": provider, "model_type": model_type, "model": model, "credentials": credentials, }, - }, + ), headers={ "X-Plugin-ID": plugin_id, "Content-Type": "application/json", @@ -72,7 +80,7 @@ class PluginModelClient(BasePluginClient): return None def validate_provider_credentials( - self, tenant_id: str, user_id: str, plugin_id: str, provider: str, credentials: dict + self, tenant_id: str, user_id: str | None, plugin_id: str, provider: str, credentials: dict ) -> bool: """ validate the credentials of the provider @@ -81,13 +89,13 @@ class PluginModelClient(BasePluginClient): "POST", f"plugin/{tenant_id}/dispatch/model/validate_provider_credentials", PluginBasicBooleanResponse, - data={ - "user_id": user_id, - "data": { + data=self._dispatch_payload( + user_id=user_id, + data={ "provider": provider, "credentials": credentials, }, - }, + ), headers={ "X-Plugin-ID": plugin_id, "Content-Type": "application/json", @@ -105,7 +113,7 @@ class PluginModelClient(BasePluginClient): def validate_model_credentials( self, tenant_id: str, - user_id: str, + user_id: str | None, plugin_id: str, provider: str, model_type: str, @@ -119,15 +127,15 @@ class PluginModelClient(BasePluginClient): "POST", 
f"plugin/{tenant_id}/dispatch/model/validate_model_credentials", PluginBasicBooleanResponse, - data={ - "user_id": user_id, - "data": { + data=self._dispatch_payload( + user_id=user_id, + data={ "provider": provider, "model_type": model_type, "model": model, "credentials": credentials, }, - }, + ), headers={ "X-Plugin-ID": plugin_id, "Content-Type": "application/json", @@ -145,7 +153,7 @@ class PluginModelClient(BasePluginClient): def invoke_llm( self, tenant_id: str, - user_id: str, + user_id: str | None, plugin_id: str, provider: str, model: str, @@ -164,9 +172,9 @@ class PluginModelClient(BasePluginClient): path=f"plugin/{tenant_id}/dispatch/llm/invoke", type_=LLMResultChunk, data=jsonable_encoder( - { - "user_id": user_id, - "data": { + self._dispatch_payload( + user_id=user_id, + data={ "provider": provider, "model_type": "llm", "model": model, @@ -177,7 +185,7 @@ class PluginModelClient(BasePluginClient): "stop": stop, "stream": stream, }, - } + ) ), headers={ "X-Plugin-ID": plugin_id, @@ -193,7 +201,7 @@ class PluginModelClient(BasePluginClient): def get_llm_num_tokens( self, tenant_id: str, - user_id: str, + user_id: str | None, plugin_id: str, provider: str, model_type: str, @@ -210,9 +218,9 @@ class PluginModelClient(BasePluginClient): path=f"plugin/{tenant_id}/dispatch/llm/num_tokens", type_=PluginLLMNumTokensResponse, data=jsonable_encoder( - { - "user_id": user_id, - "data": { + self._dispatch_payload( + user_id=user_id, + data={ "provider": provider, "model_type": model_type, "model": model, @@ -220,7 +228,7 @@ class PluginModelClient(BasePluginClient): "prompt_messages": prompt_messages, "tools": tools, }, - } + ) ), headers={ "X-Plugin-ID": plugin_id, @@ -236,7 +244,7 @@ class PluginModelClient(BasePluginClient): def invoke_text_embedding( self, tenant_id: str, - user_id: str, + user_id: str | None, plugin_id: str, provider: str, model: str, @@ -252,9 +260,9 @@ class PluginModelClient(BasePluginClient): 
path=f"plugin/{tenant_id}/dispatch/text_embedding/invoke", type_=EmbeddingResult, data=jsonable_encoder( - { - "user_id": user_id, - "data": { + self._dispatch_payload( + user_id=user_id, + data={ "provider": provider, "model_type": "text-embedding", "model": model, @@ -262,7 +270,7 @@ class PluginModelClient(BasePluginClient): "texts": texts, "input_type": input_type, }, - } + ) ), headers={ "X-Plugin-ID": plugin_id, @@ -278,7 +286,7 @@ class PluginModelClient(BasePluginClient): def invoke_multimodal_embedding( self, tenant_id: str, - user_id: str, + user_id: str | None, plugin_id: str, provider: str, model: str, @@ -294,9 +302,9 @@ class PluginModelClient(BasePluginClient): path=f"plugin/{tenant_id}/dispatch/multimodal_embedding/invoke", type_=EmbeddingResult, data=jsonable_encoder( - { - "user_id": user_id, - "data": { + self._dispatch_payload( + user_id=user_id, + data={ "provider": provider, "model_type": "text-embedding", "model": model, @@ -304,7 +312,7 @@ class PluginModelClient(BasePluginClient): "documents": documents, "input_type": input_type, }, - } + ) ), headers={ "X-Plugin-ID": plugin_id, @@ -320,7 +328,7 @@ class PluginModelClient(BasePluginClient): def get_text_embedding_num_tokens( self, tenant_id: str, - user_id: str, + user_id: str | None, plugin_id: str, provider: str, model: str, @@ -335,16 +343,16 @@ class PluginModelClient(BasePluginClient): path=f"plugin/{tenant_id}/dispatch/text_embedding/num_tokens", type_=PluginTextEmbeddingNumTokensResponse, data=jsonable_encoder( - { - "user_id": user_id, - "data": { + self._dispatch_payload( + user_id=user_id, + data={ "provider": provider, "model_type": "text-embedding", "model": model, "credentials": credentials, "texts": texts, }, - } + ) ), headers={ "X-Plugin-ID": plugin_id, @@ -360,7 +368,7 @@ class PluginModelClient(BasePluginClient): def invoke_rerank( self, tenant_id: str, - user_id: str, + user_id: str | None, plugin_id: str, provider: str, model: str, @@ -378,9 +386,9 @@ class 
PluginModelClient(BasePluginClient): path=f"plugin/{tenant_id}/dispatch/rerank/invoke", type_=RerankResult, data=jsonable_encoder( - { - "user_id": user_id, - "data": { + self._dispatch_payload( + user_id=user_id, + data={ "provider": provider, "model_type": "rerank", "model": model, @@ -390,7 +398,7 @@ class PluginModelClient(BasePluginClient): "score_threshold": score_threshold, "top_n": top_n, }, - } + ) ), headers={ "X-Plugin-ID": plugin_id, @@ -406,13 +414,13 @@ class PluginModelClient(BasePluginClient): def invoke_multimodal_rerank( self, tenant_id: str, - user_id: str, + user_id: str | None, plugin_id: str, provider: str, model: str, credentials: dict, - query: dict, - docs: list[dict], + query: MultimodalRerankInput, + docs: list[MultimodalRerankInput], score_threshold: float | None = None, top_n: int | None = None, ) -> RerankResult: @@ -424,9 +432,9 @@ class PluginModelClient(BasePluginClient): path=f"plugin/{tenant_id}/dispatch/multimodal_rerank/invoke", type_=RerankResult, data=jsonable_encoder( - { - "user_id": user_id, - "data": { + self._dispatch_payload( + user_id=user_id, + data={ "provider": provider, "model_type": "rerank", "model": model, @@ -436,7 +444,7 @@ class PluginModelClient(BasePluginClient): "score_threshold": score_threshold, "top_n": top_n, }, - } + ) ), headers={ "X-Plugin-ID": plugin_id, @@ -451,7 +459,7 @@ class PluginModelClient(BasePluginClient): def invoke_tts( self, tenant_id: str, - user_id: str, + user_id: str | None, plugin_id: str, provider: str, model: str, @@ -467,9 +475,9 @@ class PluginModelClient(BasePluginClient): path=f"plugin/{tenant_id}/dispatch/tts/invoke", type_=PluginStringResultResponse, data=jsonable_encoder( - { - "user_id": user_id, - "data": { + self._dispatch_payload( + user_id=user_id, + data={ "provider": provider, "model_type": "tts", "model": model, @@ -478,7 +486,7 @@ class PluginModelClient(BasePluginClient): "content_text": content_text, "voice": voice, }, - } + ) ), headers={ "X-Plugin-ID": 
plugin_id, @@ -496,7 +504,7 @@ class PluginModelClient(BasePluginClient): def get_tts_model_voices( self, tenant_id: str, - user_id: str, + user_id: str | None, plugin_id: str, provider: str, model: str, @@ -511,16 +519,16 @@ class PluginModelClient(BasePluginClient): path=f"plugin/{tenant_id}/dispatch/tts/model/voices", type_=PluginVoicesResponse, data=jsonable_encoder( - { - "user_id": user_id, - "data": { + self._dispatch_payload( + user_id=user_id, + data={ "provider": provider, "model_type": "tts", "model": model, "credentials": credentials, "language": language, }, - } + ) ), headers={ "X-Plugin-ID": plugin_id, @@ -540,7 +548,7 @@ class PluginModelClient(BasePluginClient): def invoke_speech_to_text( self, tenant_id: str, - user_id: str, + user_id: str | None, plugin_id: str, provider: str, model: str, @@ -555,16 +563,16 @@ class PluginModelClient(BasePluginClient): path=f"plugin/{tenant_id}/dispatch/speech2text/invoke", type_=PluginStringResultResponse, data=jsonable_encoder( - { - "user_id": user_id, - "data": { + self._dispatch_payload( + user_id=user_id, + data={ "provider": provider, "model_type": "speech2text", "model": model, "credentials": credentials, "file": binascii.hexlify(file.read()).decode(), }, - } + ) ), headers={ "X-Plugin-ID": plugin_id, @@ -580,7 +588,7 @@ class PluginModelClient(BasePluginClient): def invoke_moderation( self, tenant_id: str, - user_id: str, + user_id: str | None, plugin_id: str, provider: str, model: str, @@ -595,16 +603,16 @@ class PluginModelClient(BasePluginClient): path=f"plugin/{tenant_id}/dispatch/moderation/invoke", type_=PluginBasicBooleanResponse, data=jsonable_encoder( - { - "user_id": user_id, - "data": { + self._dispatch_payload( + user_id=user_id, + data={ "provider": provider, "model_type": "moderation", "model": model, "credentials": credentials, "text": text, }, - } + ) ), headers={ "X-Plugin-ID": plugin_id, diff --git a/api/core/plugin/impl/model_runtime.py b/api/core/plugin/impl/model_runtime.py new file 
mode 100644 index 0000000000..22c846b6de --- /dev/null +++ b/api/core/plugin/impl/model_runtime.py @@ -0,0 +1,499 @@ +from __future__ import annotations + +import hashlib +import logging +from collections.abc import Generator, Iterable, Sequence +from threading import Lock +from typing import IO, Any, Union + +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk +from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool +from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelType +from graphon.model_runtime.entities.provider_entities import ProviderEntity +from graphon.model_runtime.entities.rerank_entities import MultimodalRerankInput, RerankResult +from graphon.model_runtime.entities.text_embedding_entities import EmbeddingInputType, EmbeddingResult +from graphon.model_runtime.runtime import ModelRuntime +from pydantic import ValidationError +from redis import RedisError + +from configs import dify_config +from core.plugin.entities.plugin_daemon import PluginModelProviderEntity +from core.plugin.impl.asset import PluginAssetManager +from core.plugin.impl.model import PluginModelClient +from extensions.ext_redis import redis_client +from models.provider_ids import ModelProviderID + +logger = logging.getLogger(__name__) + +# `TS` means tenant scope +TENANT_SCOPE_SCHEMA_CACHE_USER_ID = "__DIFY_TS__" + + +class PluginModelRuntime(ModelRuntime): + """Plugin-backed runtime adapter bound to tenant context and optional caller scope.""" + + tenant_id: str + user_id: str | None + client: PluginModelClient + _provider_entities: tuple[ProviderEntity, ...] 
| None + _provider_entities_lock: Lock + + def __init__(self, tenant_id: str, user_id: str | None, client: PluginModelClient) -> None: + if client is None: + raise ValueError("client is required.") + self.tenant_id = tenant_id + self.user_id = user_id + self.client = client + self._provider_entities = None + self._provider_entities_lock = Lock() + + def fetch_model_providers(self) -> Sequence[ProviderEntity]: + if self._provider_entities is not None: + return self._provider_entities + + with self._provider_entities_lock: + if self._provider_entities is None: + self._provider_entities = tuple( + self._to_provider_entity(provider) for provider in self.client.fetch_model_providers(self.tenant_id) + ) + + return self._provider_entities + + def get_provider_icon(self, *, provider: str, icon_type: str, lang: str) -> tuple[bytes, str]: + provider_schema = self._get_provider_schema(provider) + + if icon_type.lower() == "icon_small": + if not provider_schema.icon_small: + raise ValueError(f"Provider {provider} does not have small icon.") + file_name = ( + provider_schema.icon_small.zh_Hans if lang.lower() == "zh_hans" else provider_schema.icon_small.en_US + ) + elif icon_type.lower() == "icon_small_dark": + if not provider_schema.icon_small_dark: + raise ValueError(f"Provider {provider} does not have small dark icon.") + file_name = ( + provider_schema.icon_small_dark.zh_Hans + if lang.lower() == "zh_hans" + else provider_schema.icon_small_dark.en_US + ) + else: + raise ValueError(f"Unsupported icon type: {icon_type}.") + + if not file_name: + raise ValueError(f"Provider {provider} does not have icon.") + + image_mime_types = { + "jpg": "image/jpeg", + "jpeg": "image/jpeg", + "png": "image/png", + "gif": "image/gif", + "bmp": "image/bmp", + "tiff": "image/tiff", + "tif": "image/tiff", + "webp": "image/webp", + "svg": "image/svg+xml", + "ico": "image/vnd.microsoft.icon", + "heif": "image/heif", + "heic": "image/heic", + } + + extension = file_name.split(".")[-1] + mime_type 
= image_mime_types.get(extension, "image/png") + return PluginAssetManager().fetch_asset(tenant_id=self.tenant_id, id=file_name), mime_type + + def validate_provider_credentials(self, *, provider: str, credentials: dict[str, Any]) -> None: + plugin_id, provider_name = self._split_provider(provider) + self.client.validate_provider_credentials( + tenant_id=self.tenant_id, + user_id=self.user_id, + plugin_id=plugin_id, + provider=provider_name, + credentials=credentials, + ) + + def validate_model_credentials( + self, + *, + provider: str, + model_type: ModelType, + model: str, + credentials: dict[str, Any], + ) -> None: + plugin_id, provider_name = self._split_provider(provider) + self.client.validate_model_credentials( + tenant_id=self.tenant_id, + user_id=self.user_id, + plugin_id=plugin_id, + provider=provider_name, + model_type=model_type.value, + model=model, + credentials=credentials, + ) + + def get_model_schema( + self, + *, + provider: str, + model_type: ModelType, + model: str, + credentials: dict[str, Any], + ) -> AIModelEntity | None: + cache_key = self._get_schema_cache_key( + provider=provider, + model_type=model_type, + model=model, + credentials=credentials, + ) + + cached_schema_json = None + try: + cached_schema_json = redis_client.get(cache_key) + except (RedisError, RuntimeError) as exc: + logger.warning( + "Failed to read plugin model schema cache for model %s: %s", + model, + str(exc), + exc_info=True, + ) + + if cached_schema_json: + try: + return AIModelEntity.model_validate_json(cached_schema_json) + except ValidationError: + logger.warning("Failed to validate cached plugin model schema for model %s", model, exc_info=True) + try: + redis_client.delete(cache_key) + except (RedisError, RuntimeError) as exc: + logger.warning( + "Failed to delete invalid plugin model schema cache for model %s: %s", + model, + str(exc), + exc_info=True, + ) + + plugin_id, provider_name = self._split_provider(provider) + schema = self.client.get_model_schema( + 
tenant_id=self.tenant_id, + user_id=self.user_id, + plugin_id=plugin_id, + provider=provider_name, + model_type=model_type.value, + model=model, + credentials=credentials, + ) + + if schema: + try: + redis_client.setex(cache_key, dify_config.PLUGIN_MODEL_SCHEMA_CACHE_TTL, schema.model_dump_json()) + except (RedisError, RuntimeError) as exc: + logger.warning( + "Failed to write plugin model schema cache for model %s: %s", + model, + str(exc), + exc_info=True, + ) + + return schema + + def invoke_llm( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + model_parameters: dict[str, Any], + prompt_messages: Sequence[PromptMessage], + tools: list[PromptMessageTool] | None, + stop: Sequence[str] | None, + stream: bool, + ) -> Union[LLMResult, Generator[LLMResultChunk, None, None]]: + plugin_id, provider_name = self._split_provider(provider) + return self.client.invoke_llm( + tenant_id=self.tenant_id, + user_id=self.user_id, + plugin_id=plugin_id, + provider=provider_name, + model=model, + credentials=credentials, + model_parameters=model_parameters, + prompt_messages=list(prompt_messages), + tools=tools, + stop=list(stop) if stop else None, + stream=stream, + ) + + def get_llm_num_tokens( + self, + *, + provider: str, + model_type: ModelType, + model: str, + credentials: dict[str, Any], + prompt_messages: Sequence[PromptMessage], + tools: Sequence[PromptMessageTool] | None, + ) -> int: + if not dify_config.PLUGIN_BASED_TOKEN_COUNTING_ENABLED: + return 0 + + plugin_id, provider_name = self._split_provider(provider) + return self.client.get_llm_num_tokens( + tenant_id=self.tenant_id, + user_id=self.user_id, + plugin_id=plugin_id, + provider=provider_name, + model_type=model_type.value, + model=model, + credentials=credentials, + prompt_messages=list(prompt_messages), + tools=list(tools) if tools else None, + ) + + def invoke_text_embedding( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + texts: list[str], + input_type: 
EmbeddingInputType, + ) -> EmbeddingResult: + plugin_id, provider_name = self._split_provider(provider) + return self.client.invoke_text_embedding( + tenant_id=self.tenant_id, + user_id=self.user_id, + plugin_id=plugin_id, + provider=provider_name, + model=model, + credentials=credentials, + texts=texts, + input_type=input_type, + ) + + def invoke_multimodal_embedding( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + documents: list[dict[str, Any]], + input_type: EmbeddingInputType, + ) -> EmbeddingResult: + plugin_id, provider_name = self._split_provider(provider) + return self.client.invoke_multimodal_embedding( + tenant_id=self.tenant_id, + user_id=self.user_id, + plugin_id=plugin_id, + provider=provider_name, + model=model, + credentials=credentials, + documents=documents, + input_type=input_type, + ) + + def get_text_embedding_num_tokens( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + texts: list[str], + ) -> list[int]: + plugin_id, provider_name = self._split_provider(provider) + return self.client.get_text_embedding_num_tokens( + tenant_id=self.tenant_id, + user_id=self.user_id, + plugin_id=plugin_id, + provider=provider_name, + model=model, + credentials=credentials, + texts=texts, + ) + + def invoke_rerank( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + query: str, + docs: list[str], + score_threshold: float | None, + top_n: int | None, + ) -> RerankResult: + plugin_id, provider_name = self._split_provider(provider) + return self.client.invoke_rerank( + tenant_id=self.tenant_id, + user_id=self.user_id, + plugin_id=plugin_id, + provider=provider_name, + model=model, + credentials=credentials, + query=query, + docs=docs, + score_threshold=score_threshold, + top_n=top_n, + ) + + def invoke_multimodal_rerank( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + query: MultimodalRerankInput, + docs: list[MultimodalRerankInput], + score_threshold: 
float | None, + top_n: int | None, + ) -> RerankResult: + plugin_id, provider_name = self._split_provider(provider) + return self.client.invoke_multimodal_rerank( + tenant_id=self.tenant_id, + user_id=self.user_id, + plugin_id=plugin_id, + provider=provider_name, + model=model, + credentials=credentials, + query=query, + docs=docs, + score_threshold=score_threshold, + top_n=top_n, + ) + + def invoke_tts( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + content_text: str, + voice: str, + ) -> Iterable[bytes]: + plugin_id, provider_name = self._split_provider(provider) + return self.client.invoke_tts( + tenant_id=self.tenant_id, + user_id=self.user_id, + plugin_id=plugin_id, + provider=provider_name, + model=model, + credentials=credentials, + content_text=content_text, + voice=voice, + ) + + def get_tts_model_voices( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + language: str | None, + ) -> Any: + plugin_id, provider_name = self._split_provider(provider) + return self.client.get_tts_model_voices( + tenant_id=self.tenant_id, + user_id=self.user_id, + plugin_id=plugin_id, + provider=provider_name, + model=model, + credentials=credentials, + language=language, + ) + + def invoke_speech_to_text( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + file: IO[bytes], + ) -> str: + plugin_id, provider_name = self._split_provider(provider) + return self.client.invoke_speech_to_text( + tenant_id=self.tenant_id, + user_id=self.user_id, + plugin_id=plugin_id, + provider=provider_name, + model=model, + credentials=credentials, + file=file, + ) + + def invoke_moderation( + self, + *, + provider: str, + model: str, + credentials: dict[str, Any], + text: str, + ) -> bool: + plugin_id, provider_name = self._split_provider(provider) + return self.client.invoke_moderation( + tenant_id=self.tenant_id, + user_id=self.user_id, + plugin_id=plugin_id, + provider=provider_name, + model=model, + 
credentials=credentials, + text=text, + ) + + def _get_provider_short_name_alias(self, provider: PluginModelProviderEntity) -> str: + """ + Expose a bare provider alias only for the canonical provider mapping. + + Multiple plugins can publish the same short provider slug. If every + provider entity keeps that slug in ``provider_name``, callers that still + resolve by short name become order-dependent. Restrict the alias to the + provider selected by ``ModelProviderID`` so legacy short-name lookups + remain deterministic while the runtime surface stays canonical. + """ + try: + canonical_provider_id = ModelProviderID(provider.provider) + except ValueError: + return "" + + if canonical_provider_id.plugin_id != provider.plugin_id: + return "" + if canonical_provider_id.provider_name != provider.provider: + return "" + + return provider.provider + + def _to_provider_entity(self, provider: PluginModelProviderEntity) -> ProviderEntity: + declaration = provider.declaration.model_copy(deep=True) + declaration.provider = f"{provider.plugin_id}/{provider.provider}" + declaration.provider_name = self._get_provider_short_name_alias(provider) + return declaration + + def _get_provider_schema(self, provider: str) -> ProviderEntity: + providers = self.fetch_model_providers() + provider_entity = next((item for item in providers if item.provider == provider), None) + if provider_entity is None: + provider_entity = next((item for item in providers if provider == item.provider_name), None) + if provider_entity is None: + raise ValueError(f"Invalid provider: {provider}") + return provider_entity + + def _get_schema_cache_key( + self, + *, + provider: str, + model_type: ModelType, + model: str, + credentials: dict[str, Any], + ) -> str: + # The plugin daemon distinguishes ``None`` from an explicit empty-string + # caller id, so the cache must only collapse ``None`` into tenant scope. 
+ cache_user_id = TENANT_SCOPE_SCHEMA_CACHE_USER_ID if self.user_id is None else self.user_id + cache_key = f"{self.tenant_id}:{provider}:{model_type.value}:{model}:{cache_user_id}" + sorted_credentials = sorted(credentials.items()) if credentials else [] + if not sorted_credentials: + return cache_key + hashed_credentials = ":".join( + [hashlib.md5(f"{key}:{value}".encode()).hexdigest() for key, value in sorted_credentials] + ) + return f"{cache_key}:{hashed_credentials}" + + def _split_provider(self, provider: str) -> tuple[str, str]: + provider_id = ModelProviderID(provider) + return provider_id.plugin_id, provider_id.provider_name diff --git a/api/core/plugin/impl/model_runtime_factory.py b/api/core/plugin/impl/model_runtime_factory.py new file mode 100644 index 0000000000..4b29a6fc56 --- /dev/null +++ b/api/core/plugin/impl/model_runtime_factory.py @@ -0,0 +1,90 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + +from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory + +from core.plugin.impl.model import PluginModelClient + +if TYPE_CHECKING: + from core.model_manager import ModelManager + from core.plugin.impl.model_runtime import PluginModelRuntime + from core.provider_manager import ProviderManager + + +class PluginModelAssembly: + """Compose request-scoped model views on top of a single plugin runtime.""" + + tenant_id: str + user_id: str | None + _model_runtime: PluginModelRuntime | None + _model_provider_factory: ModelProviderFactory | None + _provider_manager: ProviderManager | None + _model_manager: ModelManager | None + + def __init__(self, *, tenant_id: str, user_id: str | None = None) -> None: + self.tenant_id = tenant_id + self.user_id = user_id + self._model_runtime = None + self._model_provider_factory = None + self._provider_manager = None + self._model_manager = None + + @property + def model_runtime(self) -> PluginModelRuntime: + if self._model_runtime is None: + 
self._model_runtime = create_plugin_model_runtime(tenant_id=self.tenant_id, user_id=self.user_id) + return self._model_runtime + + @property + def model_provider_factory(self) -> ModelProviderFactory: + if self._model_provider_factory is None: + self._model_provider_factory = ModelProviderFactory(model_runtime=self.model_runtime) + return self._model_provider_factory + + @property + def provider_manager(self) -> ProviderManager: + if self._provider_manager is None: + from core.provider_manager import ProviderManager + + self._provider_manager = ProviderManager(model_runtime=self.model_runtime) + return self._provider_manager + + @property + def model_manager(self) -> ModelManager: + if self._model_manager is None: + from core.model_manager import ModelManager + + self._model_manager = ModelManager(provider_manager=self.provider_manager) + return self._model_manager + + +def create_plugin_model_assembly(*, tenant_id: str, user_id: str | None = None) -> PluginModelAssembly: + """Create a request-scoped assembly that shares one plugin runtime across model views.""" + return PluginModelAssembly(tenant_id=tenant_id, user_id=user_id) + + +def create_plugin_model_runtime(*, tenant_id: str, user_id: str | None = None) -> PluginModelRuntime: + """Create a plugin runtime with its client dependency fully composed.""" + from core.plugin.impl.model_runtime import PluginModelRuntime + + return PluginModelRuntime( + tenant_id=tenant_id, + user_id=user_id, + client=PluginModelClient(), + ) + + +def create_plugin_model_provider_factory(*, tenant_id: str, user_id: str | None = None) -> ModelProviderFactory: + """Create a tenant-bound model provider factory for service flows.""" + return create_plugin_model_assembly(tenant_id=tenant_id, user_id=user_id).model_provider_factory + + +def create_plugin_provider_manager(*, tenant_id: str, user_id: str | None = None) -> ProviderManager: + """Create a tenant-bound provider manager for service flows.""" + return 
create_plugin_model_assembly(tenant_id=tenant_id, user_id=user_id).provider_manager + + +def create_plugin_model_manager(*, tenant_id: str, user_id: str | None = None) -> ModelManager: + """Create a tenant-bound model manager for service flows.""" + return create_plugin_model_assembly(tenant_id=tenant_id, user_id=user_id).model_manager diff --git a/api/core/plugin/utils/converter.py b/api/core/plugin/utils/converter.py index 53bcd9e9c6..90350f8400 100644 --- a/api/core/plugin/utils/converter.py +++ b/api/core/plugin/utils/converter.py @@ -1,7 +1,8 @@ from typing import Any +from graphon.file import File + from core.tools.entities.tool_entities import ToolSelector -from dify_graph.file.models import File def convert_parameters_to_plugin_format(parameters: dict[str, Any]) -> dict[str, Any]: diff --git a/api/core/prompt/advanced_prompt_transform.py b/api/core/prompt/advanced_prompt_transform.py index ce9f7e64b2..19b5e9223a 100644 --- a/api/core/prompt/advanced_prompt_transform.py +++ b/api/core/prompt/advanced_prompt_transform.py @@ -1,6 +1,18 @@ from collections.abc import Mapping, Sequence from typing import cast +from graphon.file import File, file_manager +from graphon.model_runtime.entities import ( + AssistantPromptMessage, + PromptMessage, + PromptMessageRole, + SystemPromptMessage, + TextPromptMessageContent, + UserPromptMessage, +) +from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes +from graphon.runtime import VariablePool + from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.helper.code_executor.jinja2.jinja2_formatter import Jinja2Formatter from core.memory.token_buffer_memory import TokenBufferMemory @@ -8,18 +20,6 @@ from core.model_manager import ModelInstance from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig from core.prompt.prompt_transform import PromptTransform from 
core.prompt.utils.prompt_template_parser import PromptTemplateParser -from dify_graph.file import file_manager -from dify_graph.file.models import File -from dify_graph.model_runtime.entities import ( - AssistantPromptMessage, - PromptMessage, - PromptMessageRole, - SystemPromptMessage, - TextPromptMessageContent, - UserPromptMessage, -) -from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes -from dify_graph.runtime import VariablePool class AdvancedPromptTransform(PromptTransform): diff --git a/api/core/prompt/agent_history_prompt_transform.py b/api/core/prompt/agent_history_prompt_transform.py index d09a46bfde..9be70199b7 100644 --- a/api/core/prompt/agent_history_prompt_transform.py +++ b/api/core/prompt/agent_history_prompt_transform.py @@ -1,16 +1,17 @@ from typing import cast +from graphon.model_runtime.entities.message_entities import ( + PromptMessage, + SystemPromptMessage, + UserPromptMessage, +) +from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel + from core.app.entities.app_invoke_entities import ( ModelConfigWithCredentialsEntity, ) from core.memory.token_buffer_memory import TokenBufferMemory from core.prompt.prompt_transform import PromptTransform -from dify_graph.model_runtime.entities.message_entities import ( - PromptMessage, - SystemPromptMessage, - UserPromptMessage, -) -from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel class AgentHistoryPromptTransform(PromptTransform): diff --git a/api/core/prompt/entities/advanced_prompt_entities.py b/api/core/prompt/entities/advanced_prompt_entities.py index 667f5ef099..b98fd8c179 100644 --- a/api/core/prompt/entities/advanced_prompt_entities.py +++ b/api/core/prompt/entities/advanced_prompt_entities.py @@ -1,50 +1,7 @@ -from typing import Literal +from graphon.prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig 
-from pydantic import BaseModel - -from dify_graph.model_runtime.entities.message_entities import PromptMessageRole - - -class ChatModelMessage(BaseModel): - """ - Chat Message. - """ - - text: str - role: PromptMessageRole - edition_type: Literal["basic", "jinja2"] | None = None - - -class CompletionModelPromptTemplate(BaseModel): - """ - Completion Model Prompt Template. - """ - - text: str - edition_type: Literal["basic", "jinja2"] | None = None - - -class MemoryConfig(BaseModel): - """ - Memory Config. - """ - - class RolePrefix(BaseModel): - """ - Role Prefix. - """ - - user: str - assistant: str - - class WindowConfig(BaseModel): - """ - Window Config. - """ - - enabled: bool - size: int | None = None - - role_prefix: RolePrefix | None = None - window: WindowConfig - query_prompt_template: str | None = None +__all__ = [ + "ChatModelMessage", + "CompletionModelPromptTemplate", + "MemoryConfig", +] diff --git a/api/core/prompt/prompt_transform.py b/api/core/prompt/prompt_transform.py index 951736831f..4539ae9f11 100644 --- a/api/core/prompt/prompt_transform.py +++ b/api/core/prompt/prompt_transform.py @@ -1,11 +1,12 @@ from typing import Any +from graphon.model_runtime.entities.message_entities import PromptMessage +from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelPropertyKey + from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.prompt.entities.advanced_prompt_entities import MemoryConfig -from dify_graph.model_runtime.entities.message_entities import PromptMessage -from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelPropertyKey class PromptTransform: diff --git a/api/core/prompt/simple_prompt_transform.py b/api/core/prompt/simple_prompt_transform.py index 10c44349ae..c706353ffe 100644 --- a/api/core/prompt/simple_prompt_transform.py +++ 
b/api/core/prompt/simple_prompt_transform.py @@ -4,14 +4,8 @@ from collections.abc import Mapping, Sequence from enum import StrEnum, auto from typing import TYPE_CHECKING, Any, cast -from core.app.app_config.entities import PromptTemplateEntity -from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity -from core.memory.token_buffer_memory import TokenBufferMemory -from core.prompt.entities.advanced_prompt_entities import MemoryConfig -from core.prompt.prompt_transform import PromptTransform -from core.prompt.utils.prompt_template_parser import PromptTemplateParser -from dify_graph.file import file_manager -from dify_graph.model_runtime.entities.message_entities import ( +from graphon.file import file_manager +from graphon.model_runtime.entities.message_entities import ( ImagePromptMessageContent, PromptMessage, PromptMessageContentUnionTypes, @@ -19,10 +13,17 @@ from dify_graph.model_runtime.entities.message_entities import ( TextPromptMessageContent, UserPromptMessage, ) + +from core.app.app_config.entities import PromptTemplateEntity +from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity +from core.memory.token_buffer_memory import TokenBufferMemory +from core.prompt.entities.advanced_prompt_entities import MemoryConfig +from core.prompt.prompt_transform import PromptTransform +from core.prompt.utils.prompt_template_parser import PromptTemplateParser from models.model import AppMode if TYPE_CHECKING: - from dify_graph.file.models import File + from graphon.file import File class ModelMode(StrEnum): diff --git a/api/core/prompt/utils/prompt_message_util.py b/api/core/prompt/utils/prompt_message_util.py index 85a2201395..dbda749925 100644 --- a/api/core/prompt/utils/prompt_message_util.py +++ b/api/core/prompt/utils/prompt_message_util.py @@ -1,8 +1,7 @@ from collections.abc import Sequence from typing import Any, cast -from core.prompt.simple_prompt_transform import ModelMode -from 
dify_graph.model_runtime.entities import ( +from graphon.model_runtime.entities import ( AssistantPromptMessage, AudioPromptMessageContent, ImagePromptMessageContent, @@ -12,6 +11,8 @@ from dify_graph.model_runtime.entities import ( TextPromptMessageContent, ) +from core.prompt.simple_prompt_transform import ModelMode + class PromptMessageUtil: @staticmethod diff --git a/api/core/provider_manager.py b/api/core/provider_manager.py index 6d2be0ab7a..30933239f6 100644 --- a/api/core/provider_manager.py +++ b/api/core/provider_manager.py @@ -1,10 +1,20 @@ +from __future__ import annotations + import contextlib import json from collections import defaultdict from collections.abc import Sequence from json import JSONDecodeError -from typing import Any, cast +from typing import TYPE_CHECKING, Any, cast +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.entities.provider_entities import ( + ConfigurateMethod, + CredentialFormSchema, + FormType, + ProviderEntity, +) +from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from sqlalchemy import select from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import Session @@ -28,14 +38,6 @@ from core.entities.provider_entities import ( from core.helper import encrypter from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType from core.helper.position_helper import is_filtered -from dify_graph.model_runtime.entities.model_entities import ModelType -from dify_graph.model_runtime.entities.provider_entities import ( - ConfigurateMethod, - CredentialFormSchema, - FormType, - ProviderEntity, -) -from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from extensions import ext_hosting_provider from extensions.ext_database import db from extensions.ext_redis import redis_client @@ -53,15 +55,25 @@ from models.provider import ( from models.provider_ids import 
ModelProviderID from services.feature_service import FeatureService +if TYPE_CHECKING: + from graphon.model_runtime.runtime import ModelRuntime + class ProviderManager: """ - ProviderManager is a class that manages the model providers includes Hosting and Customize Model Providers. + ProviderManager manages tenant-scoped model provider configuration. + + The runtime adapter is injected by the composition layer so this class stays + focused on configuration assembly instead of constructing plugin runtimes. + Request-bound managers may carry caller identity in that runtime, and the + resulting ``ProviderConfiguration`` objects must reuse it for downstream + model-type and schema lookups. """ - def __init__(self): + def __init__(self, model_runtime: ModelRuntime): self.decoding_rsa_key = None self.decoding_cipher_rsa = None + self._model_runtime = model_runtime def get_configurations(self, tenant_id: str) -> ProviderConfigurations: """ @@ -127,7 +139,7 @@ class ProviderManager: ) # Get all provider entities - model_provider_factory = ModelProviderFactory(tenant_id) + model_provider_factory = ModelProviderFactory(model_runtime=self._model_runtime) provider_entities = model_provider_factory.get_providers() # Get All preferred provider types of the workspace @@ -255,6 +267,7 @@ class ProviderManager: custom_configuration=custom_configuration, model_settings=model_settings, ) + provider_configuration.bind_model_runtime(self._model_runtime) provider_configurations[str(provider_id_entity)] = provider_configuration @@ -321,7 +334,7 @@ class ProviderManager: if not default_model: return None - model_provider_factory = ModelProviderFactory(tenant_id) + model_provider_factory = ModelProviderFactory(model_runtime=self._model_runtime) provider_schema = model_provider_factory.get_provider_schema(provider=default_model.provider_name) return DefaultModelEntity( @@ -392,7 +405,7 @@ class ProviderManager: # create default model default_model = TenantDefaultModel( tenant_id=tenant_id, 
- model_type=model_type.value, + model_type=model_type.to_origin_model_type(), provider_name=provider, model_name=model, ) diff --git a/api/core/rag/data_post_processor/data_post_processor.py b/api/core/rag/data_post_processor/data_post_processor.py index 33eb5f963a..b872ea8a8f 100644 --- a/api/core/rag/data_post_processor/data_post_processor.py +++ b/api/core/rag/data_post_processor/data_post_processor.py @@ -1,3 +1,5 @@ +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from typing_extensions import TypedDict from core.model_manager import ModelInstance, ModelManager @@ -8,8 +10,6 @@ from core.rag.rerank.entity.weight import KeywordSetting, VectorSetting, Weights from core.rag.rerank.rerank_base import BaseRerankRunner from core.rag.rerank.rerank_factory import RerankRunnerFactory from core.rag.rerank.rerank_type import RerankMode -from dify_graph.model_runtime.entities.model_entities import ModelType -from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError class RerankingModelDict(TypedDict): @@ -52,11 +52,10 @@ class DataPostProcessor: documents: list[Document], score_threshold: float | None = None, top_n: int | None = None, - user: str | None = None, query_type: QueryType = QueryType.TEXT_QUERY, ) -> list[Document]: if self.rerank_runner: - documents = self.rerank_runner.run(query, documents, score_threshold, top_n, user, query_type) + documents = self.rerank_runner.run(query, documents, score_threshold, top_n, query_type) if self.reorder_runner: documents = self.reorder_runner.run(documents) @@ -106,9 +105,9 @@ class DataPostProcessor: ) -> ModelInstance | None: if reranking_model: try: - model_manager = ModelManager() - reranking_provider_name = reranking_model["reranking_provider_name"] - reranking_model_name = reranking_model["reranking_model_name"] + model_manager = ModelManager.for_tenant(tenant_id=tenant_id) + reranking_provider_name = 
reranking_model.get("reranking_provider_name") + reranking_model_name = reranking_model.get("reranking_model_name") if not reranking_provider_name or not reranking_model_name: return None rerank_model_instance = model_manager.get_model_instance( diff --git a/api/core/rag/datasource/retrieval_service.py b/api/core/rag/datasource/retrieval_service.py index 713319ab9d..cc6ec12c75 100644 --- a/api/core/rag/datasource/retrieval_service.py +++ b/api/core/rag/datasource/retrieval_service.py @@ -4,6 +4,7 @@ from concurrent.futures import ThreadPoolExecutor from typing import Any, NotRequired from flask import Flask, current_app +from graphon.model_runtime.entities.model_entities import ModelType from sqlalchemy import select from sqlalchemy.orm import Session, load_only from typing_extensions import TypedDict @@ -23,7 +24,6 @@ from core.rag.models.document import Document from core.rag.rerank.rerank_type import RerankMode from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.tools.signature import sign_upload_file -from dify_graph.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db from models.dataset import ( ChildChunk, @@ -328,7 +328,7 @@ class RetrievalService: str(dataset.tenant_id), str(RerankMode.RERANKING_MODEL), reranking_model, None, False ) if dataset.is_multimodal: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=dataset.tenant_id) is_support_vision = model_manager.check_model_support_vision( tenant_id=dataset.tenant_id, provider=reranking_model["reranking_provider_name"], diff --git a/api/core/rag/datasource/vdb/vector_factory.py b/api/core/rag/datasource/vdb/vector_factory.py index cd12cd3fae..5a8d3a2f3f 100644 --- a/api/core/rag/datasource/vdb/vector_factory.py +++ b/api/core/rag/datasource/vdb/vector_factory.py @@ -4,6 +4,7 @@ import time from abc import ABC, abstractmethod from typing import Any +from graphon.model_runtime.entities.model_entities import 
ModelType from sqlalchemy import select from configs import dify_config @@ -14,7 +15,6 @@ from core.rag.embedding.cached_embedding import CacheEmbedding from core.rag.embedding.embedding_base import Embeddings from core.rag.index_processor.constant.doc_type import DocType from core.rag.models.document import Document -from dify_graph.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db from extensions.ext_redis import redis_client from extensions.ext_storage import storage @@ -303,7 +303,7 @@ class Vector: redis_client.delete(collection_exist_cache_key) def _get_embeddings(self) -> Embeddings: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=self._dataset.tenant_id) embedding_model = model_manager.get_model_instance( tenant_id=self._dataset.tenant_id, diff --git a/api/core/rag/docstore/dataset_docstore.py b/api/core/rag/docstore/dataset_docstore.py index cd27113245..e5b794f80d 100644 --- a/api/core/rag/docstore/dataset_docstore.py +++ b/api/core/rag/docstore/dataset_docstore.py @@ -3,12 +3,12 @@ from __future__ import annotations from collections.abc import Sequence from typing import Any +from graphon.model_runtime.entities.model_entities import ModelType from sqlalchemy import func, select from core.model_manager import ModelManager from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.models.document import AttachmentDocument, Document -from dify_graph.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db from models.dataset import ChildChunk, Dataset, DocumentSegment, SegmentAttachmentBinding @@ -73,7 +73,7 @@ class DatasetDocumentStore: max_position = 0 embedding_model = None if self._dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=self._dataset.tenant_id) embedding_model = model_manager.get_model_instance( 
tenant_id=self._dataset.tenant_id, provider=self._dataset.embedding_model_provider, diff --git a/api/core/rag/embedding/cached_embedding.py b/api/core/rag/embedding/cached_embedding.py index 6d1b65a055..3bdad00712 100644 --- a/api/core/rag/embedding/cached_embedding.py +++ b/api/core/rag/embedding/cached_embedding.py @@ -4,14 +4,14 @@ import pickle from typing import Any, cast import numpy as np +from graphon.model_runtime.entities.model_entities import ModelPropertyKey +from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from sqlalchemy.exc import IntegrityError from configs import dify_config from core.entities.embedding_type import EmbeddingInputType from core.model_manager import ModelInstance from core.rag.embedding.embedding_base import Embeddings -from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey -from dify_graph.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from extensions.ext_database import db from extensions.ext_redis import redis_client from libs import helper @@ -21,9 +21,8 @@ logger = logging.getLogger(__name__) class CacheEmbedding(Embeddings): - def __init__(self, model_instance: ModelInstance, user: str | None = None): + def __init__(self, model_instance: ModelInstance): self._model_instance = model_instance - self._user = user def embed_documents(self, texts: list[str]) -> list[list[float]]: """Embed search docs in batches of 10.""" @@ -65,7 +64,7 @@ class CacheEmbedding(Embeddings): batch_texts = embedding_queue_texts[i : i + max_chunks] embedding_result = self._model_instance.invoke_text_embedding( - texts=batch_texts, user=self._user, input_type=EmbeddingInputType.DOCUMENT + texts=batch_texts, input_type=EmbeddingInputType.DOCUMENT ) for vector in embedding_result.embeddings: @@ -147,7 +146,6 @@ class CacheEmbedding(Embeddings): embedding_result = self._model_instance.invoke_multimodal_embedding( 
multimodel_documents=batch_multimodel_documents, - user=self._user, input_type=EmbeddingInputType.DOCUMENT, ) @@ -202,7 +200,7 @@ class CacheEmbedding(Embeddings): return [float(x) for x in decoded_embedding] try: embedding_result = self._model_instance.invoke_text_embedding( - texts=[text], user=self._user, input_type=EmbeddingInputType.QUERY + texts=[text], input_type=EmbeddingInputType.QUERY ) embedding_results = embedding_result.embeddings[0] @@ -245,7 +243,7 @@ class CacheEmbedding(Embeddings): return [float(x) for x in decoded_embedding] try: embedding_result = self._model_instance.invoke_multimodal_embedding( - multimodel_documents=[multimodel_document], user=self._user, input_type=EmbeddingInputType.QUERY + multimodel_documents=[multimodel_document], input_type=EmbeddingInputType.QUERY ) embedding_results = embedding_result.embeddings[0] diff --git a/api/core/rag/index_processor/processor/paragraph_index_processor.py b/api/core/rag/index_processor/processor/paragraph_index_processor.py index 726cc062f6..5c10ffbf2d 100644 --- a/api/core/rag/index_processor/processor/paragraph_index_processor.py +++ b/api/core/rag/index_processor/processor/paragraph_index_processor.py @@ -8,11 +8,23 @@ from typing import Any, cast logger = logging.getLogger(__name__) +from graphon.file import File, FileTransferMethod, FileType, file_manager +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage +from graphon.model_runtime.entities.message_entities import ( + ImagePromptMessageContent, + PromptMessage, + PromptMessageContentUnionTypes, + TextPromptMessageContent, + UserPromptMessage, +) +from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType + +from core.app.file_access import DatabaseFileAccessController from core.app.llm import deduct_llm_quota from core.entities.knowledge_entities import PreviewDetail from core.llm_generator.prompts import DEFAULT_GENERATOR_SUMMARY_PROMPT from core.model_manager import ModelInstance -from 
core.provider_manager import ProviderManager +from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager from core.rag.cleaner.clean_processor import CleanProcessor from core.rag.data_post_processor.data_post_processor import RerankingModelDict from core.rag.datasource.keyword.keyword_factory import Keyword @@ -27,16 +39,7 @@ from core.rag.index_processor.index_processor_base import BaseIndexProcessor, Su from core.rag.models.document import AttachmentDocument, Document, MultimodalGeneralStructureChunk from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.tools.utils.text_processing_utils import remove_leading_symbols -from dify_graph.file import File, FileTransferMethod, FileType, file_manager -from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMUsage -from dify_graph.model_runtime.entities.message_entities import ( - ImagePromptMessageContent, - PromptMessage, - PromptMessageContentUnionTypes, - TextPromptMessageContent, - UserPromptMessage, -) -from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelType +from core.workflow.file_reference import build_file_reference from extensions.ext_database import db from factories.file_factory import build_from_mapping from libs import helper @@ -48,6 +51,8 @@ from services.account_service import AccountService from services.entities.knowledge_entities.knowledge_entities import Rule from services.summary_index_service import SummaryIndexService +_file_access_controller = DatabaseFileAccessController() + class ParagraphIndexProcessor(BaseIndexProcessor): def extract(self, extract_setting: ExtractSetting, **kwargs) -> list[Document]: @@ -410,7 +415,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor): # If default prompt doesn't have {language} placeholder, use it as-is pass - provider_manager = ProviderManager() + provider_manager = create_plugin_provider_manager(tenant_id=tenant_id) provider_model_bundle = 
provider_manager.get_provider_model_bundle( tenant_id, model_provider_name, ModelType.LLM ) @@ -555,6 +560,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor): file_obj = build_from_mapping( mapping=mapping, tenant_id=tenant_id, + access_controller=_file_access_controller, ) file_objects.append(file_obj) except Exception as e: @@ -604,11 +610,12 @@ class ParagraphIndexProcessor(BaseIndexProcessor): filename=upload_file.name, extension="." + upload_file.extension, mime_type=upload_file.mime_type, - tenant_id=tenant_id, type=FileType.IMAGE, transfer_method=FileTransferMethod.LOCAL_FILE, remote_url=upload_file.source_url, - related_id=upload_file.id, + reference=build_file_reference( + record_id=str(upload_file.id), + ), size=upload_file.size, storage_key=upload_file.key, ) diff --git a/api/core/rag/models/document.py b/api/core/rag/models/document.py index dc3b771406..087736d0b0 100644 --- a/api/core/rag/models/document.py +++ b/api/core/rag/models/document.py @@ -2,10 +2,9 @@ from abc import ABC, abstractmethod from collections.abc import Sequence from typing import Any +from graphon.file import File from pydantic import BaseModel, Field -from dify_graph.file import File - class ChildDocument(BaseModel): """Class for storing a piece of text and associated metadata.""" diff --git a/api/core/rag/rerank/rerank_base.py b/api/core/rag/rerank/rerank_base.py index 88acb75133..cc65262527 100644 --- a/api/core/rag/rerank/rerank_base.py +++ b/api/core/rag/rerank/rerank_base.py @@ -12,7 +12,6 @@ class BaseRerankRunner(ABC): documents: list[Document], score_threshold: float | None = None, top_n: int | None = None, - user: str | None = None, query_type: QueryType = QueryType.TEXT_QUERY, ) -> list[Document]: """ @@ -21,7 +20,6 @@ class BaseRerankRunner(ABC): :param documents: documents for reranking :param score_threshold: score threshold :param top_n: top n - :param user: unique user id if needed :return: """ raise NotImplementedError diff --git 
a/api/core/rag/rerank/rerank_model.py b/api/core/rag/rerank/rerank_model.py index fcb14ffc52..211a9f5c5c 100644 --- a/api/core/rag/rerank/rerank_model.py +++ b/api/core/rag/rerank/rerank_model.py @@ -1,12 +1,13 @@ import base64 +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.entities.rerank_entities import RerankResult + from core.model_manager import ModelInstance, ModelManager from core.rag.index_processor.constant.doc_type import DocType from core.rag.index_processor.constant.query_type import QueryType from core.rag.models.document import Document from core.rag.rerank.rerank_base import BaseRerankRunner -from dify_graph.model_runtime.entities.model_entities import ModelType -from dify_graph.model_runtime.entities.rerank_entities import RerankResult from extensions.ext_database import db from extensions.ext_storage import storage from models.model import UploadFile @@ -22,7 +23,6 @@ class RerankModelRunner(BaseRerankRunner): documents: list[Document], score_threshold: float | None = None, top_n: int | None = None, - user: str | None = None, query_type: QueryType = QueryType.TEXT_QUERY, ) -> list[Document]: """ @@ -31,10 +31,11 @@ class RerankModelRunner(BaseRerankRunner): :param documents: documents for reranking :param score_threshold: score threshold :param top_n: top n - :param user: unique user id if needed :return: """ - model_manager = ModelManager() + model_manager = ModelManager.for_tenant( + tenant_id=self.rerank_model_instance.provider_model_bundle.configuration.tenant_id + ) is_support_vision = model_manager.check_model_support_vision( tenant_id=self.rerank_model_instance.provider_model_bundle.configuration.tenant_id, provider=self.rerank_model_instance.provider, @@ -43,12 +44,12 @@ class RerankModelRunner(BaseRerankRunner): ) if not is_support_vision: if query_type == QueryType.TEXT_QUERY: - rerank_result, unique_documents = self.fetch_text_rerank(query, documents, score_threshold, top_n, user) + 
rerank_result, unique_documents = self.fetch_text_rerank(query, documents, score_threshold, top_n) else: return documents else: rerank_result, unique_documents = self.fetch_multimodal_rerank( - query, documents, score_threshold, top_n, user, query_type + query, documents, score_threshold, top_n, query_type ) rerank_documents = [] @@ -73,7 +74,6 @@ class RerankModelRunner(BaseRerankRunner): documents: list[Document], score_threshold: float | None = None, top_n: int | None = None, - user: str | None = None, ) -> tuple[RerankResult, list[Document]]: """ Fetch text rerank @@ -81,7 +81,6 @@ class RerankModelRunner(BaseRerankRunner): :param documents: documents for reranking :param score_threshold: score threshold :param top_n: top n - :param user: unique user id if needed :return: """ docs = [] @@ -103,7 +102,7 @@ class RerankModelRunner(BaseRerankRunner): unique_documents.append(document) rerank_result = self.rerank_model_instance.invoke_rerank( - query=query, docs=docs, score_threshold=score_threshold, top_n=top_n, user=user + query=query, docs=docs, score_threshold=score_threshold, top_n=top_n ) return rerank_result, unique_documents @@ -113,7 +112,6 @@ class RerankModelRunner(BaseRerankRunner): documents: list[Document], score_threshold: float | None = None, top_n: int | None = None, - user: str | None = None, query_type: QueryType = QueryType.TEXT_QUERY, ) -> tuple[RerankResult, list[Document]]: """ @@ -122,7 +120,6 @@ class RerankModelRunner(BaseRerankRunner): :param documents: documents for reranking :param score_threshold: score threshold :param top_n: top n - :param user: unique user id if needed :param query_type: query type :return: rerank result """ @@ -168,7 +165,7 @@ class RerankModelRunner(BaseRerankRunner): documents = unique_documents if query_type == QueryType.TEXT_QUERY: - rerank_result, unique_documents = self.fetch_text_rerank(query, documents, score_threshold, top_n, user) + rerank_result, unique_documents = self.fetch_text_rerank(query, documents, 
score_threshold, top_n) return rerank_result, unique_documents elif query_type == QueryType.IMAGE_QUERY: # Query file info within db.session context to ensure thread-safe access @@ -181,7 +178,7 @@ class RerankModelRunner(BaseRerankRunner): "content_type": DocType.IMAGE, } rerank_result = self.rerank_model_instance.invoke_multimodal_rerank( - query=file_query_dict, docs=docs, score_threshold=score_threshold, top_n=top_n, user=user + query=file_query_dict, docs=docs, score_threshold=score_threshold, top_n=top_n ) return rerank_result, unique_documents else: diff --git a/api/core/rag/rerank/weight_rerank.py b/api/core/rag/rerank/weight_rerank.py index 7edd05d2d1..49123e13d0 100644 --- a/api/core/rag/rerank/weight_rerank.py +++ b/api/core/rag/rerank/weight_rerank.py @@ -2,6 +2,7 @@ import math from collections import Counter import numpy as np +from graphon.model_runtime.entities.model_entities import ModelType from core.model_manager import ModelManager from core.rag.datasource.keyword.jieba.jieba_keyword_table_handler import JiebaKeywordTableHandler @@ -11,7 +12,6 @@ from core.rag.index_processor.constant.query_type import QueryType from core.rag.models.document import Document from core.rag.rerank.entity.weight import VectorSetting, Weights from core.rag.rerank.rerank_base import BaseRerankRunner -from dify_graph.model_runtime.entities.model_entities import ModelType class WeightRerankRunner(BaseRerankRunner): @@ -25,7 +25,6 @@ class WeightRerankRunner(BaseRerankRunner): documents: list[Document], score_threshold: float | None = None, top_n: int | None = None, - user: str | None = None, query_type: QueryType = QueryType.TEXT_QUERY, ) -> list[Document]: """ @@ -34,7 +33,6 @@ class WeightRerankRunner(BaseRerankRunner): :param documents: documents for reranking :param score_threshold: score threshold :param top_n: top n - :param user: unique user id if needed :return: """ @@ -163,7 +161,7 @@ class WeightRerankRunner(BaseRerankRunner): """ query_vector_scores = [] - 
model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id) embedding_model = model_manager.get_model_instance( tenant_id=tenant_id, diff --git a/api/core/rag/retrieval/dataset_retrieval.py b/api/core/rag/retrieval/dataset_retrieval.py index 52061fd93d..1abea6639e 100644 --- a/api/core/rag/retrieval/dataset_retrieval.py +++ b/api/core/rag/retrieval/dataset_retrieval.py @@ -9,6 +9,11 @@ from collections.abc import Generator, Mapping from typing import Any, Union, cast from flask import Flask, current_app +from graphon.file import File, FileTransferMethod, FileType +from graphon.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMUsage +from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool +from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType +from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from sqlalchemy import and_, func, literal, or_, select from sqlalchemy.orm import Session @@ -56,6 +61,7 @@ from core.rag.retrieval.template_prompts import ( ) from core.tools.signature import sign_upload_file from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool +from core.workflow.file_reference import build_file_reference from core.workflow.nodes.knowledge_retrieval import exc from core.workflow.nodes.knowledge_retrieval.retrieval import ( KnowledgeRetrievalRequest, @@ -63,13 +69,9 @@ from core.workflow.nodes.knowledge_retrieval.retrieval import ( SourceChildChunk, SourceMetadata, ) -from dify_graph.file import File, FileTransferMethod, FileType -from dify_graph.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMUsage -from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool -from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelType -from 
dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from extensions.ext_database import db from extensions.ext_redis import redis_client +from libs.helper import parse_uuid_str_or_none from libs.json_in_md_parser import parse_and_check_json_markdown from models import UploadFile from models.dataset import ( @@ -160,7 +162,7 @@ class DatasetRetrieval: if request.model_provider is None or request.model_name is None or request.query is None: raise ValueError("model_provider, model_name, and query are required for single retrieval mode") - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=request.tenant_id, user_id=request.user_id) model_instance = model_manager.get_model_instance( tenant_id=request.tenant_id, model_type=ModelType.LLM, @@ -383,23 +385,27 @@ class DatasetRetrieval: return None, [] retrieve_config = config.retrieve_config - # check model is support tool calling - model_type_instance = model_config.provider_model_bundle.model_type_instance - model_type_instance = cast(LargeLanguageModel, model_type_instance) - - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id, user_id=user_id) model_instance = model_manager.get_model_instance( tenant_id=tenant_id, model_type=ModelType.LLM, provider=model_config.provider, model=model_config.model ) + model_type_instance = cast(LargeLanguageModel, model_instance.model_type_instance) - # get model schema + # Reuse the caller-bound model instance for both schema resolution and + # downstream planner/invoke calls so a single request never mixes + # tenant-scope and request-bound runtimes. 
model_schema = model_type_instance.get_model_schema( - model=model_config.model, credentials=model_config.credentials + model=model_instance.model_name, + credentials=model_instance.credentials, ) if not model_schema: return None, [] + model_config.provider_model_bundle = model_instance.provider_model_bundle + model_config.credentials = model_instance.credentials + model_config.model_schema = model_schema + planning_strategy = PlanningStrategy.REACT_ROUTER features = model_schema.features if features: @@ -517,11 +523,12 @@ class DatasetRetrieval: filename=upload_file.name, extension="." + upload_file.extension, mime_type=upload_file.mime_type, - tenant_id=segment.tenant_id, type=FileType.IMAGE, transfer_method=FileTransferMethod.LOCAL_FILE, remote_url=upload_file.source_url, - related_id=upload_file.id, + reference=build_file_reference( + record_id=str(upload_file.id), + ), size=upload_file.size, storage_key=upload_file.key, url=sign_upload_file(upload_file.id, upload_file.extension), @@ -986,6 +993,24 @@ class DatasetRetrieval: ) ) + @staticmethod + def _resolve_creator_user_role(user_from: str) -> CreatorUserRole | None: + """Map runtime user source values to dataset query audit roles. + + Workflow run context uses the hyphenated ``end-user`` value, while + ``DatasetQuery.created_by_role`` persists the underscore-based + ``CreatorUserRole.END_USER`` enum. Query logging is a side effect, so an + unsupported value should be skipped instead of aborting retrieval. + """ + normalized_user_from = str(user_from).strip().lower().replace("-", "_") + if normalized_user_from == CreatorUserRole.ACCOUNT.value: + return CreatorUserRole.ACCOUNT + if normalized_user_from == CreatorUserRole.END_USER.value: + return CreatorUserRole.END_USER + + logger.warning("Skipping dataset query audit log for unsupported user_from=%r", user_from) + return None + def _on_query( self, query: str | None, @@ -996,10 +1021,18 @@ class DatasetRetrieval: user_id: str, ): """ - Handle query. 
+ Persist dataset query audit rows for retrieval requests. """ if not query and not attachment_ids: return + created_by = parse_uuid_str_or_none(user_id) + if created_by is None: + logger.debug( + "Skipping dataset query log: empty created_by user_id (user_from=%s, app_id=%s)", + user_from, + app_id, + ) + return dataset_queries = [] for dataset_id in dataset_ids: contents = [] @@ -1015,7 +1048,7 @@ class DatasetRetrieval: source=DatasetQuerySource.APP, source_app_id=app_id, created_by_role=CreatorUserRole(user_from), - created_by=user_id, + created_by=created_by, ) dataset_queries.append(dataset_query) if dataset_queries: @@ -1411,7 +1444,7 @@ class DatasetRetrieval: raise ValueError("metadata_model_config is required") # get metadata model instance # fetch model config - model_instance, model_config = self._fetch_model_config(tenant_id, metadata_model_config) + model_instance, model_config = self._fetch_model_config(tenant_id, metadata_model_config, user_id=user_id) # fetch prompt messages prompt_messages, stop = self._get_prompt_template( @@ -1430,7 +1463,6 @@ class DatasetRetrieval: model_parameters=model_config.parameters, stop=stop, stream=True, - user=user_id, ), ) @@ -1533,7 +1565,7 @@ class DatasetRetrieval: return filters def _fetch_model_config( - self, tenant_id: str, model: ModelConfig + self, tenant_id: str, model: ModelConfig, user_id: str | None = None ) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]: """ Fetch model config @@ -1543,7 +1575,7 @@ class DatasetRetrieval: model_name = model.name provider_name = model.provider - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id, user_id=user_id) model_instance = model_manager.get_model_instance( tenant_id=tenant_id, model_type=ModelType.LLM, provider=provider_name, model=model_name ) diff --git a/api/core/rag/retrieval/router/multi_dataset_function_call_router.py b/api/core/rag/retrieval/router/multi_dataset_function_call_router.py index 
23a2ac8386..dce7b6226c 100644 --- a/api/core/rag/retrieval/router/multi_dataset_function_call_router.py +++ b/api/core/rag/retrieval/router/multi_dataset_function_call_router.py @@ -1,9 +1,10 @@ from typing import Union +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage +from graphon.model_runtime.entities.message_entities import PromptMessageTool, SystemPromptMessage, UserPromptMessage + from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.model_manager import ModelInstance -from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMUsage -from dify_graph.model_runtime.entities.message_entities import PromptMessageTool, SystemPromptMessage, UserPromptMessage class FunctionCallMultiDatasetRouter: diff --git a/api/core/rag/retrieval/router/multi_dataset_react_route.py b/api/core/rag/retrieval/router/multi_dataset_react_route.py index ea110fa0a7..dd280cdf6a 100644 --- a/api/core/rag/retrieval/router/multi_dataset_react_route.py +++ b/api/core/rag/retrieval/router/multi_dataset_react_route.py @@ -1,15 +1,17 @@ from collections.abc import Generator, Sequence from typing import Union +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage +from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool +from graphon.model_runtime.entities.model_entities import ModelType + from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.app.llm import deduct_llm_quota -from core.model_manager import ModelInstance +from core.model_manager import ModelInstance, ModelManager from core.prompt.advanced_prompt_transform import AdvancedPromptTransform from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate from core.rag.retrieval.output_parser.react_output import ReactAction from core.rag.retrieval.output_parser.structured_chat import 
StructuredChatOutputParser -from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMUsage -from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool PREFIX = """Respond to the human as helpfully and accurately as possible. You have access to the following tools:""" @@ -119,6 +121,7 @@ class ReactMultiDatasetRouter: memory_config=None, memory=None, model_config=model_config, + model_instance=model_instance, ) result_text, usage = self._invoke_llm( completion_param=model_config.parameters, @@ -150,19 +153,24 @@ class ReactMultiDatasetRouter: :param stop: stop :return: """ - invoke_result: Generator[LLMResult, None, None] = model_instance.invoke_llm( + bound_model_instance = ModelManager.for_tenant(tenant_id=tenant_id, user_id=user_id).get_model_instance( + tenant_id=tenant_id, + provider=model_instance.provider, + model_type=ModelType.LLM, + model=model_instance.model_name, + ) + invoke_result: Generator[LLMResult, None, None] = bound_model_instance.invoke_llm( prompt_messages=prompt_messages, model_parameters=completion_param, stop=stop, stream=True, - user=user_id, ) # handle invoke result text, usage = self._handle_invoke_result(invoke_result=invoke_result) # deduct quota - deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage) + deduct_llm_quota(tenant_id=tenant_id, model_instance=bound_model_instance, usage=usage) return text, usage diff --git a/api/core/rag/splitter/fixed_text_splitter.py b/api/core/rag/splitter/fixed_text_splitter.py index 7a00e8a886..e6aec4a3af 100644 --- a/api/core/rag/splitter/fixed_text_splitter.py +++ b/api/core/rag/splitter/fixed_text_splitter.py @@ -6,6 +6,8 @@ import codecs import re from typing import Any +from graphon.model_runtime.model_providers.__base.tokenizers.gpt2_tokenizer import GPT2Tokenizer + from core.model_manager import ModelInstance from core.rag.splitter.text_splitter import ( TS, @@ -15,7 +17,6 @@ from 
core.rag.splitter.text_splitter import ( Set, Union, ) -from dify_graph.model_runtime.model_providers.__base.tokenizers.gpt2_tokenizer import GPT2Tokenizer class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter): diff --git a/api/core/repositories/__init__.py b/api/core/repositories/__init__.py index 6f2826f634..cfa9962ea8 100644 --- a/api/core/repositories/__init__.py +++ b/api/core/repositories/__init__.py @@ -4,7 +4,13 @@ from __future__ import annotations from .celery_workflow_execution_repository import CeleryWorkflowExecutionRepository from .celery_workflow_node_execution_repository import CeleryWorkflowNodeExecutionRepository -from .factory import DifyCoreRepositoryFactory, RepositoryImportError +from .factory import ( + DifyCoreRepositoryFactory, + OrderConfig, + RepositoryImportError, + WorkflowExecutionRepository, + WorkflowNodeExecutionRepository, +) from .sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository from .sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository @@ -12,7 +18,10 @@ __all__ = [ "CeleryWorkflowExecutionRepository", "CeleryWorkflowNodeExecutionRepository", "DifyCoreRepositoryFactory", + "OrderConfig", "RepositoryImportError", "SQLAlchemyWorkflowExecutionRepository", "SQLAlchemyWorkflowNodeExecutionRepository", + "WorkflowExecutionRepository", + "WorkflowNodeExecutionRepository", ] diff --git a/api/core/repositories/celery_workflow_execution_repository.py b/api/core/repositories/celery_workflow_execution_repository.py index 57764574d7..465f43da73 100644 --- a/api/core/repositories/celery_workflow_execution_repository.py +++ b/api/core/repositories/celery_workflow_execution_repository.py @@ -8,11 +8,11 @@ providing improved performance by offloading database operations to background w import logging from typing import Union +from graphon.entities import WorkflowExecution from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker 
-from dify_graph.entities.workflow_execution import WorkflowExecution -from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository +from core.repositories.factory import WorkflowExecutionRepository from libs.helper import extract_tenant_id from models import Account, CreatorUserRole, EndUser from models.enums import WorkflowRunTriggeredFrom diff --git a/api/core/repositories/celery_workflow_node_execution_repository.py b/api/core/repositories/celery_workflow_node_execution_repository.py index 650cf79550..22ef44b3dc 100644 --- a/api/core/repositories/celery_workflow_node_execution_repository.py +++ b/api/core/repositories/celery_workflow_node_execution_repository.py @@ -9,11 +9,11 @@ import logging from collections.abc import Sequence from typing import Union +from graphon.entities import WorkflowNodeExecution from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker -from dify_graph.entities.workflow_node_execution import WorkflowNodeExecution -from dify_graph.repositories.workflow_node_execution_repository import ( +from core.repositories.factory import ( OrderConfig, WorkflowNodeExecutionRepository, ) @@ -148,24 +148,24 @@ class CeleryWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository): # For now, we'll re-raise the exception raise - def get_by_workflow_run( + def get_by_workflow_execution( self, - workflow_run_id: str, + workflow_execution_id: str, order_config: OrderConfig | None = None, ) -> Sequence[WorkflowNodeExecution]: """ - Retrieve all WorkflowNodeExecution instances for a specific workflow run from cache. + Retrieve all workflow node executions for a workflow execution from cache. 
Args: - workflow_run_id: The workflow run ID + workflow_execution_id: The workflow execution identifier order_config: Optional configuration for ordering results Returns: A sequence of WorkflowNodeExecution instances """ try: - # Get execution IDs for this workflow run from cache - execution_ids = self._workflow_execution_mapping.get(workflow_run_id, []) + # Get execution IDs for this workflow execution from cache + execution_ids = self._workflow_execution_mapping.get(workflow_execution_id, []) # Retrieve executions from cache result = [] @@ -182,9 +182,16 @@ class CeleryWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository): for field_name in reversed(order_config.order_by): result.sort(key=lambda x: getattr(x, field_name, 0), reverse=reverse) - logger.debug("Retrieved %d workflow node executions for run %s from cache", len(result), workflow_run_id) + logger.debug( + "Retrieved %d workflow node executions for execution %s from cache", + len(result), + workflow_execution_id, + ) return result except Exception: - logger.exception("Failed to get workflow node executions for run %s from cache", workflow_run_id) + logger.exception( + "Failed to get workflow node executions for execution %s from cache", + workflow_execution_id, + ) return [] diff --git a/api/core/repositories/factory.py b/api/core/repositories/factory.py index dc9f8c96bf..ed6d44f434 100644 --- a/api/core/repositories/factory.py +++ b/api/core/repositories/factory.py @@ -5,20 +5,45 @@ This module provides a Django-like settings system for repository implementation allowing users to configure different repository backends through string paths. 
""" -from typing import Union +from collections.abc import Sequence +from dataclasses import dataclass +from typing import Literal, Protocol, Union +from graphon.entities import WorkflowExecution, WorkflowNodeExecution from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker from configs import dify_config -from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository -from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository from libs.module_loading import import_string from models import Account, EndUser from models.enums import WorkflowRunTriggeredFrom from models.workflow import WorkflowNodeExecutionTriggeredFrom +@dataclass +class OrderConfig: + """Configuration for ordering node execution instances.""" + + order_by: list[str] + order_direction: Literal["asc", "desc"] | None = None + + +class WorkflowExecutionRepository(Protocol): + def save(self, execution: WorkflowExecution): ... + + +class WorkflowNodeExecutionRepository(Protocol): + def save(self, execution: WorkflowNodeExecution): ... + + def save_execution_data(self, execution: WorkflowNodeExecution): ... + + def get_by_workflow_execution( + self, + workflow_execution_id: str, + order_config: OrderConfig | None = None, + ) -> Sequence[WorkflowNodeExecution]: ... 
+ + class RepositoryImportError(Exception): """Raised when a repository implementation cannot be imported or instantiated.""" diff --git a/api/core/repositories/human_input_repository.py b/api/core/repositories/human_input_repository.py index 6607a87032..72d9394149 100644 --- a/api/core/repositories/human_input_repository.py +++ b/api/core/repositories/human_input_repository.py @@ -2,32 +2,22 @@ import dataclasses import json from collections.abc import Mapping, Sequence from datetime import datetime -from typing import Any +from typing import Any, Protocol +from graphon.nodes.human_input.entities import FormDefinition, HumanInputNodeData +from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from sqlalchemy import select from sqlalchemy.orm import Session, selectinload from core.db.session_factory import session_factory -from dify_graph.nodes.human_input.entities import ( +from core.workflow.human_input_compat import ( + BoundRecipient, DeliveryChannelConfig, EmailDeliveryMethod, EmailRecipients, ExternalRecipient, - FormDefinition, - HumanInputNodeData, - MemberRecipient, - WebAppDeliveryMethod, -) -from dify_graph.nodes.human_input.enums import ( - DeliveryMethodType, - HumanInputFormKind, - HumanInputFormStatus, -) -from dify_graph.repositories.human_input_form_repository import ( - FormCreateParams, - FormNotFoundError, - HumanInputFormEntity, - HumanInputFormRecipientEntity, + InteractiveSurfaceDeliveryMethod, + is_human_input_webapp_enabled, ) from libs.datetime_utils import naive_utc_now from libs.uuid_utils import uuidv7 @@ -36,6 +26,7 @@ from models.human_input import ( BackstageRecipientPayload, ConsoleDeliveryPayload, ConsoleRecipientPayload, + DeliveryMethodType, EmailExternalRecipientPayload, EmailMemberRecipientPayload, HumanInputDelivery, @@ -58,6 +49,65 @@ class _WorkspaceMemberInfo: email: str +class FormNotFoundError(Exception): + pass + + +@dataclasses.dataclass +class FormCreateParams: + workflow_execution_id: str 
| None + node_id: str + form_config: HumanInputNodeData + rendered_content: str + delivery_methods: Sequence[DeliveryChannelConfig] + display_in_ui: bool + resolved_default_values: Mapping[str, Any] + form_kind: HumanInputFormKind = HumanInputFormKind.RUNTIME + + +class HumanInputFormRecipientEntity(Protocol): + @property + def id(self) -> str: ... + + @property + def token(self) -> str: ... + + +class HumanInputFormEntity(Protocol): + @property + def id(self) -> str: ... + + @property + def submission_token(self) -> str | None: ... + + @property + def recipients(self) -> list[HumanInputFormRecipientEntity]: ... + + @property + def rendered_content(self) -> str: ... + + @property + def selected_action_id(self) -> str | None: ... + + @property + def submitted_data(self) -> Mapping[str, Any] | None: ... + + @property + def submitted(self) -> bool: ... + + @property + def status(self) -> HumanInputFormStatus: ... + + @property + def expiration_time(self) -> datetime: ... + + +class HumanInputFormRepository(Protocol): + def get_form(self, node_id: str) -> HumanInputFormEntity | None: ... + + def create_form(self, params: FormCreateParams) -> HumanInputFormEntity: ... 
+ + class _HumanInputFormRecipientEntityImpl(HumanInputFormRecipientEntity): def __init__(self, recipient_model: HumanInputFormRecipient): self._recipient_model = recipient_model @@ -77,7 +127,7 @@ class _HumanInputFormEntityImpl(HumanInputFormEntity): def __init__(self, form_model: HumanInputForm, recipient_models: Sequence[HumanInputFormRecipient]): self._form_model = form_model self._recipients = [_HumanInputFormRecipientEntityImpl(recipient) for recipient in recipient_models] - self._web_app_recipient = next( + self._interactive_surface_recipient = next( ( recipient for recipient in recipient_models @@ -98,12 +148,12 @@ class _HumanInputFormEntityImpl(HumanInputFormEntity): return self._form_model.id @property - def web_app_token(self): + def submission_token(self) -> str | None: if self._console_recipient is not None: return self._console_recipient.access_token - if self._web_app_recipient is None: + if self._interactive_surface_recipient is None: return None - return self._web_app_recipient.access_token + return self._interactive_surface_recipient.access_token @property def recipients(self) -> list[HumanInputFormRecipientEntity]: @@ -201,8 +251,16 @@ class HumanInputFormRepositoryImpl: self, *, tenant_id: str, - ): + app_id: str | None = None, + workflow_execution_id: str | None = None, + invoke_source: str | None = None, + submission_actor_id: str | None = None, + ) -> None: self._tenant_id = tenant_id + self._app_id = app_id + self._workflow_execution_id = workflow_execution_id + self._invoke_source = invoke_source + self._submission_actor_id = submission_actor_id def _delivery_method_to_model( self, @@ -219,7 +277,7 @@ class HumanInputFormRepositoryImpl: channel_payload=delivery_method.model_dump_json(), ) recipients: list[HumanInputFormRecipient] = [] - if isinstance(delivery_method, WebAppDeliveryMethod): + if isinstance(delivery_method, InteractiveSurfaceDeliveryMethod): recipient_model = HumanInputFormRecipient( form_id=form_id, 
delivery_id=delivery_id, @@ -247,16 +305,16 @@ class HumanInputFormRepositoryImpl: delivery_id: str, recipients_config: EmailRecipients, ) -> list[HumanInputFormRecipient]: - member_user_ids = [ - recipient.user_id for recipient in recipients_config.items if isinstance(recipient, MemberRecipient) + bound_reference_ids = [ + recipient.reference_id for recipient in recipients_config.items if isinstance(recipient, BoundRecipient) ] external_emails = [ recipient.email for recipient in recipients_config.items if isinstance(recipient, ExternalRecipient) ] - if recipients_config.whole_workspace: + if recipients_config.include_bound_group: members = self._query_all_workspace_members(session=session) else: - members = self._query_workspace_members_by_ids(session=session, restrict_to_user_ids=member_user_ids) + members = self._query_workspace_members_by_ids(session=session, restrict_to_user_ids=bound_reference_ids) return self._create_email_recipients_from_resolved( form_id=form_id, @@ -338,8 +396,33 @@ class HumanInputFormRepositoryImpl: rows = session.execute(stmt).all() return [_WorkspaceMemberInfo(user_id=account_id, email=email) for account_id, email in rows] + def _should_create_console_recipient( + self, + *, + form_config: HumanInputNodeData, + form_kind: HumanInputFormKind, + ) -> bool: + if form_kind != HumanInputFormKind.RUNTIME: + return False + if self._invoke_source == "debugger": + return True + if self._invoke_source == "explore": + return is_human_input_webapp_enabled(form_config) + return False + + def _should_create_backstage_recipient(self, *, form_kind: HumanInputFormKind) -> bool: + return form_kind == HumanInputFormKind.RUNTIME and ( + self._invoke_source is not None or self._submission_actor_id is not None + ) + def create_form(self, params: FormCreateParams) -> HumanInputFormEntity: form_config: HumanInputNodeData = params.form_config + app_id = self._app_id + if not app_id: + raise ValueError("app_id is required to create a human input form") + 
workflow_execution_id = params.workflow_execution_id or self._workflow_execution_id + if params.form_kind == HumanInputFormKind.RUNTIME and workflow_execution_id is None: + raise ValueError("workflow_execution_id is required for runtime human input forms") with session_factory.create_session() as session, session.begin(): # Generate unique form ID @@ -359,8 +442,8 @@ class HumanInputFormRepositoryImpl: form_model = HumanInputForm( id=form_id, tenant_id=self._tenant_id, - app_id=params.app_id, - workflow_run_id=params.workflow_execution_id, + app_id=app_id, + workflow_run_id=workflow_execution_id, form_kind=params.form_kind, node_id=params.node_id, form_definition=form_definition.model_dump_json(), @@ -379,7 +462,7 @@ class HumanInputFormRepositoryImpl: session.add(delivery_and_recipients.delivery) session.add_all(delivery_and_recipients.recipients) recipient_models.extend(delivery_and_recipients.recipients) - if params.console_recipient_required and not any( + if self._should_create_console_recipient(form_config=form_config, form_kind=params.form_kind) and not any( recipient.recipient_type == RecipientType.CONSOLE for recipient in recipient_models ): console_delivery_id = str(uuidv7()) @@ -395,13 +478,13 @@ class HumanInputFormRepositoryImpl: delivery_id=console_delivery_id, recipient_type=RecipientType.CONSOLE, recipient_payload=ConsoleRecipientPayload( - account_id=params.console_creator_account_id, + account_id=self._submission_actor_id, ).model_dump_json(), ) session.add(console_delivery) session.add(console_recipient) recipient_models.append(console_recipient) - if params.backstage_recipient_required and not any( + if self._should_create_backstage_recipient(form_kind=params.form_kind) and not any( recipient.recipient_type == RecipientType.BACKSTAGE for recipient in recipient_models ): backstage_delivery_id = str(uuidv7()) @@ -417,7 +500,7 @@ class HumanInputFormRepositoryImpl: delivery_id=backstage_delivery_id, recipient_type=RecipientType.BACKSTAGE, 
recipient_payload=BackstageRecipientPayload( - account_id=params.console_creator_account_id, + account_id=self._submission_actor_id, ).model_dump_json(), ) session.add(backstage_delivery) @@ -427,9 +510,12 @@ class HumanInputFormRepositoryImpl: return _HumanInputFormEntityImpl(form_model=form_model, recipient_models=recipient_models) - def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None: + def get_form(self, node_id: str) -> HumanInputFormEntity | None: + if self._workflow_execution_id is None: + raise ValueError("workflow_execution_id is required to load runtime human input forms") + form_query = select(HumanInputForm).where( - HumanInputForm.workflow_run_id == workflow_execution_id, + HumanInputForm.workflow_run_id == self._workflow_execution_id, HumanInputForm.node_id == node_id, HumanInputForm.tenant_id == self._tenant_id, ) diff --git a/api/core/repositories/sqlalchemy_workflow_execution_repository.py b/api/core/repositories/sqlalchemy_workflow_execution_repository.py index 55e96515ac..85d20b675d 100644 --- a/api/core/repositories/sqlalchemy_workflow_execution_repository.py +++ b/api/core/repositories/sqlalchemy_workflow_execution_repository.py @@ -6,13 +6,13 @@ import json import logging from typing import Union +from graphon.entities import WorkflowExecution +from graphon.enums import WorkflowExecutionStatus, WorkflowType +from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker -from dify_graph.entities import WorkflowExecution -from dify_graph.enums import WorkflowExecutionStatus, WorkflowType -from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository -from dify_graph.workflow_type_encoder import WorkflowRuntimeTypeConverter +from core.repositories.factory import WorkflowExecutionRepository from libs.helper import extract_tenant_id from models import ( Account, diff --git 
a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py index 7373ebc7cc..a72bfa378b 100644 --- a/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py +++ b/api/core/repositories/sqlalchemy_workflow_node_execution_repository.py @@ -10,6 +10,10 @@ from concurrent.futures import ThreadPoolExecutor from typing import Any, TypeVar, Union import psycopg2.errors +from graphon.entities import WorkflowNodeExecution +from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from graphon.model_runtime.utils.encoders import jsonable_encoder +from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from sqlalchemy import UnaryExpression, asc, desc, select from sqlalchemy.engine import Engine from sqlalchemy.exc import IntegrityError @@ -17,11 +21,7 @@ from sqlalchemy.orm import sessionmaker from tenacity import before_sleep_log, retry, retry_if_exception, stop_after_attempt from configs import dify_config -from dify_graph.entities import WorkflowNodeExecution -from dify_graph.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus -from dify_graph.model_runtime.utils.encoders import jsonable_encoder -from dify_graph.repositories.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository -from dify_graph.workflow_type_encoder import WorkflowRuntimeTypeConverter +from core.repositories.factory import OrderConfig, WorkflowNodeExecutionRepository from extensions.ext_storage import storage from libs.helper import extract_tenant_id from libs.uuid_utils import uuidv7 @@ -518,29 +518,28 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository) return db_models - def get_by_workflow_run( + def get_by_workflow_execution( self, - workflow_run_id: str, + workflow_execution_id: str, order_config: OrderConfig | None = None, triggered_from: 
WorkflowNodeExecutionTriggeredFrom = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, ) -> Sequence[WorkflowNodeExecution]: """ - Retrieve all NodeExecution instances for a specific workflow run. + Retrieve all node executions for a workflow execution. This method always queries the database to ensure complete and ordered results, but updates the cache with any retrieved executions. Args: - workflow_run_id: The workflow run ID + workflow_execution_id: The workflow execution identifier order_config: Optional configuration for ordering results order_config.order_by: List of fields to order by (e.g., ["index", "created_at"]) order_config.order_direction: Direction to order ("asc" or "desc") Returns: - A list of NodeExecution instances + A list of node execution instances """ - # Get the database models using the new method - db_models = self.get_db_models_by_workflow_run(workflow_run_id, order_config, triggered_from) + db_models = self.get_db_models_by_workflow_run(workflow_execution_id, order_config, triggered_from) with ThreadPoolExecutor(max_workers=10) as executor: domain_models = executor.map(self._to_domain_model, db_models, timeout=30) diff --git a/api/core/telemetry/__init__.py b/api/core/telemetry/__init__.py new file mode 100644 index 0000000000..ae4f53f3b7 --- /dev/null +++ b/api/core/telemetry/__init__.py @@ -0,0 +1,43 @@ +"""Telemetry facade. + +Thin public API for emitting telemetry events. All routing logic +lives in ``core.telemetry.gateway`` which is shared by both CE and EE. 
+""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from core.ops.entities.trace_entity import TraceTaskName +from core.telemetry.events import TelemetryContext, TelemetryEvent +from core.telemetry.gateway import emit as gateway_emit +from core.telemetry.gateway import get_trace_task_to_case + +if TYPE_CHECKING: + from core.ops.ops_trace_manager import TraceQueueManager + + +def emit(event: TelemetryEvent, trace_manager: TraceQueueManager | None = None) -> None: + """Emit a telemetry event. + + Translates the ``TelemetryEvent`` (keyed by ``TraceTaskName``) into a + ``TelemetryCase`` and delegates to ``core.telemetry.gateway.emit()``. + """ + case = get_trace_task_to_case().get(event.name) + if case is None: + return + + context: dict[str, object] = { + "tenant_id": event.context.tenant_id, + "user_id": event.context.user_id, + "app_id": event.context.app_id, + } + gateway_emit(case, context, event.payload, trace_manager) + + +__all__ = [ + "TelemetryContext", + "TelemetryEvent", + "TraceTaskName", + "emit", +] diff --git a/api/core/telemetry/events.py b/api/core/telemetry/events.py new file mode 100644 index 0000000000..35ace47510 --- /dev/null +++ b/api/core/telemetry/events.py @@ -0,0 +1,21 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from core.ops.entities.trace_entity import TraceTaskName + + +@dataclass(frozen=True) +class TelemetryContext: + tenant_id: str | None = None + user_id: str | None = None + app_id: str | None = None + + +@dataclass(frozen=True) +class TelemetryEvent: + name: TraceTaskName + context: TelemetryContext + payload: dict[str, Any] diff --git a/api/core/telemetry/gateway.py b/api/core/telemetry/gateway.py new file mode 100644 index 0000000000..7b013d0563 --- /dev/null +++ b/api/core/telemetry/gateway.py @@ -0,0 +1,239 @@ +"""Telemetry gateway — single routing layer for all editions. 
+ +Maps ``TelemetryCase`` → ``CaseRoute`` and dispatches events to either +the CE/EE trace pipeline (``TraceQueueManager``) or the enterprise-only +metric/log Celery queue. + +This module lives in ``core/`` so both CE and EE share one routing table +and one ``emit()`` entry point. No separate enterprise gateway module is +needed — enterprise-specific dispatch (Celery task, payload offloading) +is handled here behind lazy imports that no-op in CE. +""" + +from __future__ import annotations + +import json +import logging +import uuid +from typing import TYPE_CHECKING, Any + +from core.ops.entities.trace_entity import TraceTaskName +from enterprise.telemetry.contracts import CaseRoute, SignalType +from extensions.ext_storage import storage + +if TYPE_CHECKING: + from core.ops.ops_trace_manager import TraceQueueManager + from enterprise.telemetry.contracts import TelemetryCase + +logger = logging.getLogger(__name__) + +PAYLOAD_SIZE_THRESHOLD_BYTES = 1 * 1024 * 1024 + +# --------------------------------------------------------------------------- +# Routing table — authoritative mapping for all editions +# --------------------------------------------------------------------------- + +_case_to_trace_task: dict[TelemetryCase, TraceTaskName] | None = None +_case_routing: dict[TelemetryCase, CaseRoute] | None = None + + +def _get_case_to_trace_task() -> dict[TelemetryCase, TraceTaskName]: + global _case_to_trace_task + if _case_to_trace_task is None: + from enterprise.telemetry.contracts import TelemetryCase + + _case_to_trace_task = { + TelemetryCase.WORKFLOW_RUN: TraceTaskName.WORKFLOW_TRACE, + TelemetryCase.MESSAGE_RUN: TraceTaskName.MESSAGE_TRACE, + TelemetryCase.NODE_EXECUTION: TraceTaskName.NODE_EXECUTION_TRACE, + TelemetryCase.DRAFT_NODE_EXECUTION: TraceTaskName.DRAFT_NODE_EXECUTION_TRACE, + TelemetryCase.PROMPT_GENERATION: TraceTaskName.PROMPT_GENERATION_TRACE, + TelemetryCase.TOOL_EXECUTION: TraceTaskName.TOOL_TRACE, + TelemetryCase.MODERATION_CHECK: 
TraceTaskName.MODERATION_TRACE, + TelemetryCase.SUGGESTED_QUESTION: TraceTaskName.SUGGESTED_QUESTION_TRACE, + TelemetryCase.DATASET_RETRIEVAL: TraceTaskName.DATASET_RETRIEVAL_TRACE, + TelemetryCase.GENERATE_NAME: TraceTaskName.GENERATE_NAME_TRACE, + } + return _case_to_trace_task + + +def get_trace_task_to_case() -> dict[TraceTaskName, TelemetryCase]: + """Return TraceTaskName → TelemetryCase (inverse of _get_case_to_trace_task).""" + return {v: k for k, v in _get_case_to_trace_task().items()} + + +def _get_case_routing() -> dict[TelemetryCase, CaseRoute]: + global _case_routing + if _case_routing is None: + from enterprise.telemetry.contracts import CaseRoute, SignalType, TelemetryCase + + _case_routing = { + # TRACE — CE-eligible (flow in both CE and EE) + TelemetryCase.WORKFLOW_RUN: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=True), + TelemetryCase.MESSAGE_RUN: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=True), + TelemetryCase.TOOL_EXECUTION: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=True), + TelemetryCase.MODERATION_CHECK: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=True), + TelemetryCase.SUGGESTED_QUESTION: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=True), + TelemetryCase.DATASET_RETRIEVAL: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=True), + TelemetryCase.GENERATE_NAME: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=True), + # TRACE — enterprise-only + TelemetryCase.NODE_EXECUTION: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=False), + TelemetryCase.DRAFT_NODE_EXECUTION: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=False), + TelemetryCase.PROMPT_GENERATION: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=False), + # METRIC_LOG — enterprise-only (signal-driven, not trace) + TelemetryCase.APP_CREATED: CaseRoute(signal_type=SignalType.METRIC_LOG, ce_eligible=False), + TelemetryCase.APP_UPDATED: CaseRoute(signal_type=SignalType.METRIC_LOG, ce_eligible=False), + TelemetryCase.APP_DELETED: 
CaseRoute(signal_type=SignalType.METRIC_LOG, ce_eligible=False), + TelemetryCase.FEEDBACK_CREATED: CaseRoute(signal_type=SignalType.METRIC_LOG, ce_eligible=False), + } + return _case_routing + + +def __getattr__(name: str) -> dict: + """Lazy module-level access to routing tables.""" + if name == "CASE_ROUTING": + return _get_case_routing() + if name == "CASE_TO_TRACE_TASK": + return _get_case_to_trace_task() + raise AttributeError(f"module {__name__!r} has no attribute {name!r}") + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def is_enterprise_telemetry_enabled() -> bool: + try: + from enterprise.telemetry.exporter import is_enterprise_telemetry_enabled + + return is_enterprise_telemetry_enabled() + except Exception: + return False + + +def _handle_payload_sizing( + payload: dict[str, Any], + tenant_id: str, + event_id: str, +) -> tuple[dict[str, Any], str | None]: + """Inline or offload payload based on size. + + Returns ``(payload_for_envelope, storage_key | None)``. Payloads + exceeding ``PAYLOAD_SIZE_THRESHOLD_BYTES`` are written to object + storage and replaced with an empty dict in the envelope. 
+ """ + try: + payload_json = json.dumps(payload) + payload_size = len(payload_json.encode("utf-8")) + except (TypeError, ValueError): + logger.warning("Failed to serialize payload for sizing: event_id=%s", event_id) + return payload, None + + if payload_size <= PAYLOAD_SIZE_THRESHOLD_BYTES: + return payload, None + + storage_key = f"telemetry/{tenant_id}/{event_id}.json" + try: + storage.save(storage_key, payload_json.encode("utf-8")) + logger.debug("Stored large payload to storage: key=%s, size=%d", storage_key, payload_size) + return {}, storage_key + except Exception: + logger.warning("Failed to store large payload, inlining instead: event_id=%s", event_id, exc_info=True) + return payload, None + + +# --------------------------------------------------------------------------- +# Public API +# --------------------------------------------------------------------------- + + +def emit( + case: TelemetryCase, + context: dict[str, Any], + payload: dict[str, Any], + trace_manager: TraceQueueManager | None = None, +) -> None: + """Route a telemetry event to the correct pipeline. + + TRACE events are enqueued into ``TraceQueueManager`` (works in both CE + and EE). Enterprise-only traces are silently dropped when EE is + disabled. + + METRIC_LOG events are dispatched to the enterprise Celery queue; + silently dropped when enterprise telemetry is unavailable. 
+ """ + route = _get_case_routing().get(case) + if route is None: + logger.warning("Unknown telemetry case: %s, dropping event", case) + return + + if not route.ce_eligible and not is_enterprise_telemetry_enabled(): + logger.debug("Dropping EE-only event: case=%s (EE disabled)", case) + return + + if route.signal_type == SignalType.TRACE: + _emit_trace(case, context, payload, trace_manager) + else: + _emit_metric_log(case, context, payload) + + +def _emit_trace( + case: TelemetryCase, + context: dict[str, Any], + payload: dict[str, Any], + trace_manager: TraceQueueManager | None, +) -> None: + from core.ops.ops_trace_manager import TraceQueueManager as LocalTraceQueueManager + from core.ops.ops_trace_manager import TraceTask + + trace_task_name = _get_case_to_trace_task().get(case) + if trace_task_name is None: + logger.warning("No TraceTaskName mapping for case: %s", case) + return + + queue_manager = trace_manager or LocalTraceQueueManager( + app_id=context.get("app_id"), + user_id=context.get("user_id"), + ) + queue_manager.add_trace_task(TraceTask(trace_task_name, user_id=context.get("user_id"), **payload)) + logger.debug("Enqueued trace task: case=%s, app_id=%s", case, context.get("app_id")) + + +def _emit_metric_log( + case: TelemetryCase, + context: dict[str, Any], + payload: dict[str, Any], +) -> None: + """Build envelope and dispatch to enterprise Celery queue. + + No-ops when the enterprise telemetry task is not importable (CE mode). 
+ """ + try: + from tasks.enterprise_telemetry_task import process_enterprise_telemetry + except ImportError: + logger.debug("Enterprise metric/log dispatch unavailable, dropping: case=%s", case) + return + + tenant_id = context.get("tenant_id") or "" + event_id = str(uuid.uuid4()) + + payload_for_envelope, payload_ref = _handle_payload_sizing(payload, tenant_id, event_id) + + from enterprise.telemetry.contracts import TelemetryEnvelope + + envelope = TelemetryEnvelope( + case=case, + tenant_id=tenant_id, + event_id=event_id, + payload=payload_for_envelope, + metadata={"payload_ref": payload_ref} if payload_ref else None, + ) + + process_enterprise_telemetry.delay(envelope.model_dump_json()) + logger.debug( + "Enqueued metric/log event: case=%s, tenant_id=%s, event_id=%s", + case, + tenant_id, + event_id, + ) diff --git a/api/core/tools/__base/tool_runtime.py b/api/core/tools/__base/tool_runtime.py index 961d13f90a..5154bc9805 100644 --- a/api/core/tools/__base/tool_runtime.py +++ b/api/core/tools/__base/tool_runtime.py @@ -9,10 +9,14 @@ from core.tools.entities.tool_entities import ToolInvokeFrom class ToolRuntime(BaseModel): """ - Meta data of a tool call processing + Meta data of a tool call processing. + + ``user_id`` is optional so read-only tooling flows can stay tenant-scoped, + while execution paths may bind caller identity for model runtime lookups. 
""" tenant_id: str + user_id: str | None = None tool_id: str | None = None invoke_from: InvokeFrom | None = None tool_invoke_from: ToolInvokeFrom | None = None diff --git a/api/core/tools/builtin_tool/providers/audio/tools/asr.py b/api/core/tools/builtin_tool/providers/audio/tools/asr.py index dacc49c746..e539074303 100644 --- a/api/core/tools/builtin_tool/providers/audio/tools/asr.py +++ b/api/core/tools/builtin_tool/providers/audio/tools/asr.py @@ -2,14 +2,15 @@ import io from collections.abc import Generator from typing import Any +from graphon.file import FileType +from graphon.file.file_manager import download +from graphon.model_runtime.entities.model_entities import ModelType + from core.model_manager import ModelManager from core.plugin.entities.parameters import PluginParameterOption from core.tools.builtin_tool.tool import BuiltinTool from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter -from dify_graph.file.enums import FileType -from dify_graph.file.file_manager import download -from dify_graph.model_runtime.entities.model_entities import ModelType from services.model_provider_service import ModelProviderService @@ -22,6 +23,9 @@ class ASRTool(BuiltinTool): app_id: str | None = None, message_id: str | None = None, ) -> Generator[ToolInvokeMessage, None, None]: + if not self.runtime: + raise ValueError("Runtime is required") + runtime = self.runtime file = tool_parameters.get("audio_file") if file.type != FileType.AUDIO: # type: ignore yield self.create_text_message("not a valid audio file") @@ -29,20 +33,19 @@ class ASRTool(BuiltinTool): audio_binary = io.BytesIO(download(file)) # type: ignore audio_binary.name = "temp.mp3" provider, model = tool_parameters.get("model").split("#") # type: ignore - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=runtime.tenant_id, user_id=user_id) model_instance = model_manager.get_model_instance( - 
tenant_id=self.runtime.tenant_id, + tenant_id=runtime.tenant_id, provider=provider, model_type=ModelType.SPEECH2TEXT, model=model, ) - text = model_instance.invoke_speech2text( - file=audio_binary, - user=user_id, - ) + text = model_instance.invoke_speech2text(file=audio_binary) yield self.create_text_message(text) def get_available_models(self) -> list[tuple[str, str]]: + if not self.runtime: + raise ValueError("Runtime is required") model_provider_service = ModelProviderService() models = model_provider_service.get_models_by_model_type( tenant_id=self.runtime.tenant_id, model_type="speech2text" diff --git a/api/core/tools/builtin_tool/providers/audio/tools/tts.py b/api/core/tools/builtin_tool/providers/audio/tools/tts.py index 7818bff0ab..f49c669fe0 100644 --- a/api/core/tools/builtin_tool/providers/audio/tools/tts.py +++ b/api/core/tools/builtin_tool/providers/audio/tools/tts.py @@ -2,12 +2,13 @@ import io from collections.abc import Generator from typing import Any +from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType + from core.model_manager import ModelManager from core.plugin.entities.parameters import PluginParameterOption from core.tools.builtin_tool.tool import BuiltinTool from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter -from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey, ModelType from services.model_provider_service import ModelProviderService @@ -20,13 +21,14 @@ class TTSTool(BuiltinTool): app_id: str | None = None, message_id: str | None = None, ) -> Generator[ToolInvokeMessage, None, None]: - provider, model = tool_parameters.get("model").split("#") # type: ignore - voice = tool_parameters.get(f"voice#{provider}#{model}") - model_manager = ModelManager() if not self.runtime: raise ValueError("Runtime is required") + runtime = self.runtime + provider, model = tool_parameters.get("model").split("#") # 
type: ignore + voice = tool_parameters.get(f"voice#{provider}#{model}") + model_manager = ModelManager.for_tenant(tenant_id=runtime.tenant_id, user_id=user_id) model_instance = model_manager.get_model_instance( - tenant_id=self.runtime.tenant_id or "", + tenant_id=runtime.tenant_id or "", provider=provider, model_type=ModelType.TTS, model=model, @@ -39,12 +41,7 @@ class TTSTool(BuiltinTool): raise ValueError("Sorry, no voice available.") else: raise ValueError("Sorry, no voice available.") - tts = model_instance.invoke_tts( - content_text=tool_parameters.get("text"), # type: ignore - user=user_id, - tenant_id=self.runtime.tenant_id, - voice=voice, - ) + tts = model_instance.invoke_tts(content_text=tool_parameters.get("text"), voice=voice) # type: ignore[arg-type] buffer = io.BytesIO() for chunk in tts: buffer.write(chunk) diff --git a/api/core/tools/builtin_tool/providers/time/tools/current_time.py b/api/core/tools/builtin_tool/providers/time/tools/current_time.py index 44f94c2723..e07ca0d919 100644 --- a/api/core/tools/builtin_tool/providers/time/tools/current_time.py +++ b/api/core/tools/builtin_tool/providers/time/tools/current_time.py @@ -2,7 +2,7 @@ from collections.abc import Generator from datetime import UTC, datetime from typing import Any -from pytz import timezone as pytz_timezone +from pytz import timezone as pytz_timezone # type: ignore[import-untyped] from core.tools.builtin_tool.tool import BuiltinTool from core.tools.entities.tool_entities import ToolInvokeMessage diff --git a/api/core/tools/builtin_tool/providers/time/tools/localtime_to_timestamp.py b/api/core/tools/builtin_tool/providers/time/tools/localtime_to_timestamp.py index d0a41b940f..dc49b64dd8 100644 --- a/api/core/tools/builtin_tool/providers/time/tools/localtime_to_timestamp.py +++ b/api/core/tools/builtin_tool/providers/time/tools/localtime_to_timestamp.py @@ -2,7 +2,7 @@ from collections.abc import Generator from datetime import datetime from typing import Any -import pytz +import 
pytz # type: ignore[import-untyped] from core.tools.builtin_tool.tool import BuiltinTool from core.tools.entities.tool_entities import ToolInvokeMessage diff --git a/api/core/tools/builtin_tool/providers/time/tools/timestamp_to_localtime.py b/api/core/tools/builtin_tool/providers/time/tools/timestamp_to_localtime.py index 462e4be5ce..8045e4b980 100644 --- a/api/core/tools/builtin_tool/providers/time/tools/timestamp_to_localtime.py +++ b/api/core/tools/builtin_tool/providers/time/tools/timestamp_to_localtime.py @@ -2,7 +2,7 @@ from collections.abc import Generator from datetime import datetime from typing import Any -import pytz +import pytz # type: ignore[import-untyped] from core.tools.builtin_tool.tool import BuiltinTool from core.tools.entities.tool_entities import ToolInvokeMessage diff --git a/api/core/tools/builtin_tool/providers/time/tools/timezone_conversion.py b/api/core/tools/builtin_tool/providers/time/tools/timezone_conversion.py index e23ae3b001..e2570811d6 100644 --- a/api/core/tools/builtin_tool/providers/time/tools/timezone_conversion.py +++ b/api/core/tools/builtin_tool/providers/time/tools/timezone_conversion.py @@ -2,7 +2,7 @@ from collections.abc import Generator from datetime import datetime from typing import Any -import pytz +import pytz # type: ignore[import-untyped] from core.tools.builtin_tool.tool import BuiltinTool from core.tools.entities.tool_entities import ToolInvokeMessage diff --git a/api/core/tools/builtin_tool/tool.py b/api/core/tools/builtin_tool/tool.py index bcf58394ba..14af63a962 100644 --- a/api/core/tools/builtin_tool/tool.py +++ b/api/core/tools/builtin_tool/tool.py @@ -1,11 +1,12 @@ from __future__ import annotations +from graphon.model_runtime.entities.llm_entities import LLMResult +from graphon.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage + from core.tools.__base.tool import Tool from core.tools.__base.tool_runtime import ToolRuntime from 
core.tools.entities.tool_entities import ToolProviderType from core.tools.utils.model_invocation_utils import ModelInvocationUtils -from dify_graph.model_runtime.entities.llm_entities import LLMResult -from dify_graph.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage _SUMMARY_PROMPT = """You are a professional language researcher, you are interested in the language and you can quickly aimed at the main point of an webpage and reproduce it in your own words but @@ -53,6 +54,7 @@ class BuiltinTool(Tool): tool_type=ToolProviderType.BUILT_IN, tool_name=self.entity.identity.name, prompt_messages=prompt_messages, + caller_user_id=self.runtime.user_id, ) def tool_provider_type(self) -> ToolProviderType: @@ -69,6 +71,7 @@ class BuiltinTool(Tool): return ModelInvocationUtils.get_max_llm_context_tokens( tenant_id=self.runtime.tenant_id or "", + user_id=self.runtime.user_id, ) def get_prompt_tokens(self, prompt_messages: list[PromptMessage]) -> int: @@ -82,7 +85,9 @@ class BuiltinTool(Tool): raise ValueError("runtime is required") return ModelInvocationUtils.calculate_tokens( - tenant_id=self.runtime.tenant_id or "", prompt_messages=prompt_messages + tenant_id=self.runtime.tenant_id or "", + prompt_messages=prompt_messages, + user_id=self.runtime.user_id, ) def summary(self, user_id: str, content: str) -> str: diff --git a/api/core/tools/custom_tool/tool.py b/api/core/tools/custom_tool/tool.py index c6a84e27c6..0a2c37c563 100644 --- a/api/core/tools/custom_tool/tool.py +++ b/api/core/tools/custom_tool/tool.py @@ -6,6 +6,7 @@ from typing import Any, Union from urllib.parse import urlencode import httpx +from graphon.file.file_manager import download from core.helper import ssrf_proxy from core.tools.__base.tool import Tool @@ -13,7 +14,6 @@ from core.tools.__base.tool_runtime import ToolRuntime from core.tools.entities.tool_bundle import ApiToolBundle from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, 
ToolProviderType from core.tools.errors import ToolInvokeError, ToolParameterValidationError, ToolProviderCredentialValidationError -from dify_graph.file.file_manager import download API_TOOL_DEFAULT_TIMEOUT = ( int(getenv("API_TOOL_DEFAULT_CONNECT_TIMEOUT", "10")), diff --git a/api/core/tools/entities/api_entities.py b/api/core/tools/entities/api_entities.py index 2545290b57..d5d3d1b1d9 100644 --- a/api/core/tools/entities/api_entities.py +++ b/api/core/tools/entities/api_entities.py @@ -2,6 +2,7 @@ from collections.abc import Mapping from datetime import datetime from typing import Any, Literal +from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel, Field, field_validator from core.entities.mcp_provider import MCPAuthentication, MCPConfiguration @@ -9,7 +10,6 @@ from core.plugin.entities.plugin_daemon import CredentialType from core.tools.__base.tool import ToolParameter from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolProviderType -from dify_graph.model_runtime.utils.encoders import jsonable_encoder class ToolApiEntity(BaseModel): diff --git a/api/core/tools/mcp_tool/tool.py b/api/core/tools/mcp_tool/tool.py index 9025ff6ef1..f6d09472b3 100644 --- a/api/core/tools/mcp_tool/tool.py +++ b/api/core/tools/mcp_tool/tool.py @@ -6,6 +6,8 @@ import logging from collections.abc import Generator, Mapping from typing import Any, cast +from graphon.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata + from core.mcp.auth_client import MCPClientWithAuthRetry from core.mcp.error import MCPConnectionError from core.mcp.types import ( @@ -21,7 +23,6 @@ from core.tools.__base.tool import Tool from core.tools.__base.tool_runtime import ToolRuntime from core.tools.entities.tool_entities import ToolEntity, ToolInvokeMessage, ToolProviderType from core.tools.errors import ToolInvokeError -from dify_graph.model_runtime.entities.llm_entities import LLMUsage, 
LLMUsageMetadata logger = logging.getLogger(__name__) diff --git a/api/core/tools/signature.py b/api/core/tools/signature.py index 22e099deba..1807226924 100644 --- a/api/core/tools/signature.py +++ b/api/core/tools/signature.py @@ -3,6 +3,7 @@ import hashlib import hmac import os import time +import urllib.parse from configs import dify_config @@ -58,3 +59,43 @@ def verify_tool_file_signature(file_id: str, timestamp: str, nonce: str, sign: s current_time = int(time.time()) return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT + + +def get_signed_file_url_for_plugin(filename: str, mimetype: str, tenant_id: str, user_id: str) -> str: + """Build the signed upload URL used by the plugin-facing file upload endpoint.""" + + base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL + upload_url = f"{base_url}/files/upload/for-plugin" + timestamp = str(int(time.time())) + nonce = os.urandom(16).hex() + data_to_sign = f"upload|{filename}|{mimetype}|{tenant_id}|{user_id}|{timestamp}|{nonce}" + secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b"" + sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() + encoded_sign = base64.urlsafe_b64encode(sign).decode() + query = urllib.parse.urlencode( + { + "timestamp": timestamp, + "nonce": nonce, + "sign": encoded_sign, + "user_id": user_id, + "tenant_id": tenant_id, + } + ) + return f"{upload_url}?{query}" + + +def verify_plugin_file_signature( + *, filename: str, mimetype: str, tenant_id: str, user_id: str, timestamp: str, nonce: str, sign: str +) -> bool: + """Verify the signature used by the plugin-facing file upload endpoint.""" + + data_to_sign = f"upload|{filename}|{mimetype}|{tenant_id}|{user_id}|{timestamp}|{nonce}" + secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b"" + recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() + recalculated_encoded_sign = 
base64.urlsafe_b64encode(recalculated_sign).decode() + + if sign != recalculated_encoded_sign: + return False + + current_time = int(time.time()) + return current_time - int(timestamp) <= dify_config.FILES_ACCESS_TIMEOUT diff --git a/api/core/tools/tool_engine.py b/api/core/tools/tool_engine.py index 64212a2636..685d687d8c 100644 --- a/api/core/tools/tool_engine.py +++ b/api/core/tools/tool_engine.py @@ -7,6 +7,7 @@ from datetime import UTC, datetime from mimetypes import guess_type from typing import Any, Union, cast +from graphon.file import FileTransferMethod, FileType from yarl import URL from core.app.entities.app_invoke_entities import InvokeFrom @@ -31,8 +32,6 @@ from core.tools.errors import ( ) from core.tools.utils.message_transformer import ToolFileMessageTransformer, safe_json_value from core.tools.workflow_as_tool.tool import WorkflowTool -from dify_graph.file import FileType -from dify_graph.file.models import FileTransferMethod from extensions.ext_database import db from models.enums import CreatorUserRole, MessageFileBelongsTo from models.model import Message, MessageFile diff --git a/api/core/tools/tool_file_manager.py b/api/core/tools/tool_file_manager.py index 210f488afc..7ac29cf069 100644 --- a/api/core/tools/tool_file_manager.py +++ b/api/core/tools/tool_file_manager.py @@ -10,11 +10,12 @@ from typing import Union from uuid import uuid4 import httpx +from graphon.file import File, FileTransferMethod, get_file_type_by_mime_type from configs import dify_config from core.db.session_factory import session_factory from core.helper import ssrf_proxy -from dify_graph.file.models import ToolFile as ToolFilePydanticModel +from core.workflow.file_reference import build_file_reference from extensions.ext_storage import storage from models.model import MessageFile from models.tools import ToolFile @@ -23,6 +24,21 @@ logger = logging.getLogger(__name__) class ToolFileManager: + @staticmethod + def _build_graph_file_reference(tool_file: ToolFile) -> File: + 
extension = guess_extension(tool_file.mimetype) or ".bin" + return File( + type=get_file_type_by_mime_type(tool_file.mimetype), + transfer_method=FileTransferMethod.TOOL_FILE, + remote_url=tool_file.original_url, + reference=build_file_reference(record_id=str(tool_file.id)), + filename=tool_file.name, + extension=extension, + mime_type=tool_file.mimetype, + size=tool_file.size, + storage_key=tool_file.file_key, + ) + @staticmethod def sign_file(tool_file_id: str, extension: str) -> str: """ @@ -209,9 +225,7 @@ class ToolFileManager: return blob, tool_file.mimetype - def get_file_generator_by_tool_file_id( - self, tool_file_id: str - ) -> tuple[Generator | None, ToolFilePydanticModel | None]: + def get_file_generator_by_tool_file_id(self, tool_file_id: str) -> tuple[Generator | None, File | None]: """ get file binary @@ -233,11 +247,11 @@ class ToolFileManager: stream = storage.load_stream(tool_file.file_key) - return stream, ToolFilePydanticModel.model_validate(tool_file) + return stream, self._build_graph_file_reference(tool_file) # init tool_file_parser -from dify_graph.file.tool_file_parser import set_tool_file_manager_factory +from graphon.file.tool_file_parser import set_tool_file_manager_factory def _factory() -> ToolFileManager: diff --git a/api/core/tools/tool_label_manager.py b/api/core/tools/tool_label_manager.py index 250dd91bfd..58190d1089 100644 --- a/api/core/tools/tool_label_manager.py +++ b/api/core/tools/tool_label_manager.py @@ -1,4 +1,4 @@ -from sqlalchemy import select +from sqlalchemy import delete, select from core.tools.__base.tool_provider import ToolProviderController from core.tools.builtin_tool.provider import BuiltinToolProviderController @@ -31,7 +31,7 @@ class ToolLabelManager: raise ValueError("Unsupported tool type") # delete old labels - db.session.query(ToolLabelBinding).where(ToolLabelBinding.tool_id == provider_id).delete() + db.session.execute(delete(ToolLabelBinding).where(ToolLabelBinding.tool_id == provider_id)) # insert new 
labels for label in labels: diff --git a/api/core/tools/tool_manager.py b/api/core/tools/tool_manager.py index 23a877b7e3..a58d310313 100644 --- a/api/core/tools/tool_manager.py +++ b/api/core/tools/tool_manager.py @@ -5,9 +5,10 @@ import time from collections.abc import Generator, Mapping from os import listdir, path from threading import Lock -from typing import TYPE_CHECKING, Any, Literal, Optional, TypedDict, Union, cast +from typing import TYPE_CHECKING, Any, Literal, Optional, Protocol, TypedDict, Union, cast import sqlalchemy as sa +from graphon.runtime import VariablePool from sqlalchemy import select from sqlalchemy.orm import Session from yarl import URL @@ -24,20 +25,20 @@ from core.tools.plugin_tool.provider import PluginToolProviderController from core.tools.plugin_tool.tool import PluginTool from core.tools.utils.uuid_utils import is_valid_uuid from core.tools.workflow_as_tool.provider import WorkflowToolProviderController -from dify_graph.runtime.variable_pool import VariablePool from extensions.ext_database import db from models.provider_ids import ToolProviderID from services.enterprise.plugin_manager_service import PluginCredentialType from services.tools.mcp_tools_manage_service import MCPToolManageService if TYPE_CHECKING: - from dify_graph.nodes.tool.entities import ToolEntity + pass + +from graphon.model_runtime.utils.encoders import jsonable_encoder from core.agent.entities import AgentToolEntity from core.app.entities.app_invoke_entities import InvokeFrom from core.helper.module_import_helper import load_single_subclass_from_source from core.helper.position_helper import is_filtered -from core.plugin.entities.plugin_daemon import CredentialType from core.tools.__base.tool import Tool from core.tools.builtin_tool.provider import BuiltinToolProviderController from core.tools.builtin_tool.providers._positions import BuiltinToolProviderSort @@ -57,12 +58,11 @@ from core.tools.tool_label_manager import ToolLabelManager from 
core.tools.utils.configuration import ToolParameterConfigurationManager from core.tools.utils.encryption import create_provider_encrypter, create_tool_provider_encrypter from core.tools.workflow_as_tool.tool import WorkflowTool -from dify_graph.model_runtime.utils.encoders import jsonable_encoder from models.tools import ApiToolProvider, BuiltinToolProvider, WorkflowToolProvider from services.tools.tools_transform_service import ToolTransformService if TYPE_CHECKING: - from dify_graph.nodes.tool.entities import ToolEntity + pass logger = logging.getLogger(__name__) @@ -77,6 +77,23 @@ class EmojiIconDict(TypedDict): content: str +class WorkflowToolRuntimeSpec(Protocol): + @property + def provider_type(self) -> ToolProviderType: ... + + @property + def provider_id(self) -> str: ... + + @property + def tool_name(self) -> str: ... + + @property + def tool_configurations(self) -> Mapping[str, Any]: ... + + @property + def credential_id(self) -> str | None: ... + + class ToolManager: _builtin_provider_lock = Lock() _hardcoded_providers: dict[str, BuiltinToolProviderController] = {} @@ -167,6 +184,7 @@ class ToolManager: provider_id: str, tool_name: str, tenant_id: str, + user_id: str | None = None, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, tool_invoke_from: ToolInvokeFrom = ToolInvokeFrom.AGENT, credential_id: str | None = None, @@ -178,6 +196,7 @@ class ToolManager: :param provider_id: the id of the provider :param tool_name: the name of the tool :param tenant_id: the tenant id + :param user_id: the caller id bound to runtime-scoped model/tool lookups :param invoke_from: invoke from :param tool_invoke_from: the tool invoke from :param credential_id: the credential id @@ -196,6 +215,7 @@ class ToolManager: return builtin_tool.fork_tool_runtime( runtime=ToolRuntime( tenant_id=tenant_id, + user_id=user_id, credentials={}, invoke_from=invoke_from, tool_invoke_from=tool_invoke_from, @@ -235,11 +255,11 @@ class ToolManager: if builtin_provider is None: raise 
ToolProviderNotFoundError(f"no default provider for {provider_id}") else: - builtin_provider = ( - db.session.query(BuiltinToolProvider) + builtin_provider = db.session.scalar( + select(BuiltinToolProvider) .where(BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id)) .order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc()) - .first() + .limit(1) ) if builtin_provider is None: @@ -304,8 +324,9 @@ class ToolManager: return builtin_tool.fork_tool_runtime( runtime=ToolRuntime( tenant_id=tenant_id, + user_id=user_id, credentials=dict(decrypted_credentials), - credential_type=CredentialType.of(builtin_provider.credential_type), + credential_type=builtin_provider.credential_type, runtime_parameters={}, invoke_from=invoke_from, tool_invoke_from=tool_invoke_from, @@ -321,6 +342,7 @@ class ToolManager: return api_provider.get_tool(tool_name).fork_tool_runtime( runtime=ToolRuntime( tenant_id=tenant_id, + user_id=user_id, credentials=dict(encrypter.decrypt(credentials)), invoke_from=invoke_from, tool_invoke_from=tool_invoke_from, @@ -344,6 +366,7 @@ class ToolManager: return controller.get_tools(tenant_id=workflow_provider.tenant_id)[0].fork_tool_runtime( runtime=ToolRuntime( tenant_id=tenant_id, + user_id=user_id, credentials={}, invoke_from=invoke_from, tool_invoke_from=tool_invoke_from, @@ -352,9 +375,21 @@ class ToolManager: elif provider_type == ToolProviderType.APP: raise NotImplementedError("app provider not implemented") elif provider_type == ToolProviderType.PLUGIN: - return cls.get_plugin_provider(provider_id, tenant_id).get_tool(tool_name) + plugin_tool = cls.get_plugin_provider(provider_id, tenant_id).get_tool(tool_name) + runtime = getattr(plugin_tool, "runtime", None) + if runtime is not None: + runtime.user_id = user_id + runtime.invoke_from = invoke_from + runtime.tool_invoke_from = tool_invoke_from + return plugin_tool elif provider_type == ToolProviderType.MCP: - return 
cls.get_mcp_provider_controller(tenant_id, provider_id).get_tool(tool_name) + mcp_tool = cls.get_mcp_provider_controller(tenant_id, provider_id).get_tool(tool_name) + runtime = getattr(mcp_tool, "runtime", None) + if runtime is not None: + runtime.user_id = user_id + runtime.invoke_from = invoke_from + runtime.tool_invoke_from = tool_invoke_from + return mcp_tool else: raise ToolProviderNotFoundError(f"provider type {provider_type.value} not found") @@ -364,6 +399,7 @@ class ToolManager: tenant_id: str, app_id: str, agent_tool: AgentToolEntity, + user_id: str | None = None, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, variable_pool: Optional["VariablePool"] = None, ) -> Tool: @@ -375,6 +411,7 @@ class ToolManager: provider_id=agent_tool.provider_id, tool_name=agent_tool.tool_name, tenant_id=tenant_id, + user_id=user_id, invoke_from=invoke_from, tool_invoke_from=ToolInvokeFrom.AGENT, credential_id=agent_tool.credential_id, @@ -405,7 +442,8 @@ class ToolManager: tenant_id: str, app_id: str, node_id: str, - workflow_tool: "ToolEntity", + workflow_tool: WorkflowToolRuntimeSpec, + user_id: str | None = None, invoke_from: InvokeFrom = InvokeFrom.DEBUGGER, variable_pool: Optional["VariablePool"] = None, ) -> Tool: @@ -418,6 +456,7 @@ class ToolManager: provider_id=workflow_tool.provider_id, tool_name=workflow_tool.tool_name, tenant_id=tenant_id, + user_id=user_id, invoke_from=invoke_from, tool_invoke_from=ToolInvokeFrom.WORKFLOW, credential_id=workflow_tool.credential_id, @@ -450,6 +489,7 @@ class ToolManager: provider: str, tool_name: str, tool_parameters: dict[str, Any], + user_id: str | None = None, credential_id: str | None = None, ) -> Tool: """ @@ -460,6 +500,7 @@ class ToolManager: provider_id=provider, tool_name=tool_name, tenant_id=tenant_id, + user_id=user_id, invoke_from=InvokeFrom.SERVICE_API, tool_invoke_from=ToolInvokeFrom.PLUGIN, credential_id=credential_id, @@ -777,13 +818,13 @@ class ToolManager: :return: the provider controller, the credentials """ - 
provider: ApiToolProvider | None = ( - db.session.query(ApiToolProvider) + provider: ApiToolProvider | None = db.session.scalar( + select(ApiToolProvider) .where( ApiToolProvider.id == provider_id, ApiToolProvider.tenant_id == tenant_id, ) - .first() + .limit(1) ) if provider is None: @@ -831,13 +872,13 @@ class ToolManager: get api provider """ provider_name = provider - provider_obj: ApiToolProvider | None = ( - db.session.query(ApiToolProvider) + provider_obj: ApiToolProvider | None = db.session.scalar( + select(ApiToolProvider) .where( ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.name == provider, ) - .first() + .limit(1) ) if provider_obj is None: @@ -923,10 +964,10 @@ class ToolManager: @classmethod def generate_workflow_tool_icon_url(cls, tenant_id: str, provider_id: str) -> EmojiIconDict: try: - workflow_provider: WorkflowToolProvider | None = ( - db.session.query(WorkflowToolProvider) + workflow_provider: WorkflowToolProvider | None = db.session.scalar( + select(WorkflowToolProvider) .where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id) - .first() + .limit(1) ) if workflow_provider is None: @@ -940,10 +981,10 @@ class ToolManager: @classmethod def generate_api_tool_icon_url(cls, tenant_id: str, provider_id: str) -> EmojiIconDict: try: - api_provider: ApiToolProvider | None = ( - db.session.query(ApiToolProvider) + api_provider: ApiToolProvider | None = db.session.scalar( + select(ApiToolProvider) .where(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.id == provider_id) - .first() + .limit(1) ) if api_provider is None: @@ -1015,14 +1056,14 @@ class ToolManager: cls, parameters: list[ToolParameter], variable_pool: Optional["VariablePool"], - tool_configurations: dict[str, Any], + tool_configurations: Mapping[str, Any], typ: Literal["agent", "workflow", "tool"] = "workflow", ) -> dict[str, Any]: """ Convert tool parameters type """ - from dify_graph.nodes.tool.entities import ToolNodeData - from 
dify_graph.nodes.tool.exc import ToolParameterError + from graphon.nodes.tool.entities import ToolNodeData + from graphon.nodes.tool.exc import ToolParameterError runtime_parameters = {} for parameter in parameters: diff --git a/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py b/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py index 75b923fd8b..e63435db98 100644 --- a/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py +++ b/api/core/tools/utils/dataset_retriever/dataset_multi_retriever_tool.py @@ -1,6 +1,7 @@ import threading from flask import Flask, current_app +from graphon.model_runtime.entities.model_entities import ModelType from pydantic import BaseModel, Field from sqlalchemy import select @@ -14,7 +15,6 @@ from core.rag.rerank.rerank_model import RerankModelRunner from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool from core.tools.utils.dataset_retriever.dataset_retriever_tool import DefaultRetrievalModelDict -from dify_graph.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db from models.dataset import Dataset, Document, DocumentSegment @@ -66,7 +66,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): for thread in threads: thread.join() # do rerank for searched documents - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=self.tenant_id) rerank_model_instance = model_manager.get_model_instance( tenant_id=self.tenant_id, provider=self.reranking_provider_name, @@ -110,7 +110,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool): context_list: list[RetrievalSourceMetadata] = [] resource_number = 1 for segment in sorted_segments: - dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() + dataset = db.session.get(Dataset, segment.dataset_id) document_stmt = 
select(Document).where( Document.id == segment.document_id, Document.enabled == True, diff --git a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py index f3d390ed59..cbd8bdb36c 100644 --- a/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py +++ b/api/core/tools/utils/dataset_retriever/dataset_retriever_tool.py @@ -205,7 +205,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool): if self.return_resource: for record in records: segment = record.segment - dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first() + dataset = db.session.get(Dataset, segment.dataset_id) dataset_document_stmt = select(DatasetDocument).where( DatasetDocument.id == segment.document_id, DatasetDocument.enabled == True, diff --git a/api/core/tools/utils/message_transformer.py b/api/core/tools/utils/message_transformer.py index 6fc5fead2d..bb5b3ba76e 100644 --- a/api/core/tools/utils/message_transformer.py +++ b/api/core/tools/utils/message_transformer.py @@ -1,4 +1,5 @@ import logging +import re from collections.abc import Generator from datetime import date, datetime from decimal import Decimal @@ -7,15 +8,18 @@ from uuid import UUID import numpy as np import pytz +from graphon.file import File, FileTransferMethod, FileType from core.tools.entities.tool_entities import ToolInvokeMessage from core.tools.tool_file_manager import ToolFileManager -from dify_graph.file import File, FileTransferMethod, FileType +from core.workflow.file_reference import parse_file_reference from libs.login import current_user from models import Account logger = logging.getLogger(__name__) +_TOOL_FILE_URL_PATTERN = re.compile(r"(?:^|/+)files/tools/(?P<tool_file_id>[^/?#.]+)") + def safe_json_value(v): if isinstance(v, datetime): @@ -82,11 +86,15 @@ class ToolFileMessageTransformer: ) url = f"/files/tools/{tool_file.id}{guess_extension(tool_file.mimetype) or '.png'}" + meta = cls._with_tool_file_meta( + 
message.meta, + tool_file_id=str(tool_file.id), + ) yield ToolInvokeMessage( type=ToolInvokeMessage.MessageType.IMAGE_LINK, message=ToolInvokeMessage.TextMessage(text=url), - meta=message.meta.copy() if message.meta is not None else {}, + meta=meta, ) except Exception as e: yield ToolInvokeMessage( @@ -122,38 +130,45 @@ class ToolFileMessageTransformer: ) url = cls.get_tool_file_url(tool_file_id=tool_file.id, extension=guess_extension(tool_file.mimetype)) + meta = cls._with_tool_file_meta(meta, tool_file_id=str(tool_file.id)) # check if file is image if "image" in mimetype: yield ToolInvokeMessage( type=ToolInvokeMessage.MessageType.IMAGE_LINK, message=ToolInvokeMessage.TextMessage(text=url), - meta=meta.copy() if meta is not None else {}, + meta=meta, ) else: yield ToolInvokeMessage( type=ToolInvokeMessage.MessageType.BINARY_LINK, message=ToolInvokeMessage.TextMessage(text=url), - meta=meta.copy() if meta is not None else {}, + meta=meta, ) elif message.type == ToolInvokeMessage.MessageType.FILE: meta = message.meta or {} file = meta.get("file", None) if isinstance(file, File): if file.transfer_method == FileTransferMethod.TOOL_FILE: - assert file.related_id is not None - url = cls.get_tool_file_url(tool_file_id=file.related_id, extension=file.extension) + parsed_reference = parse_file_reference(file.reference) + if parsed_reference is None: + raise ValueError("tool file is missing reference") + url = cls.get_tool_file_url( + tool_file_id=parsed_reference.record_id, + extension=file.extension, + ) + tool_file_meta = cls._with_tool_file_meta(meta, tool_file_id=parsed_reference.record_id) if file.type == FileType.IMAGE: yield ToolInvokeMessage( type=ToolInvokeMessage.MessageType.IMAGE_LINK, message=ToolInvokeMessage.TextMessage(text=url), - meta=meta.copy() if meta is not None else {}, + meta=tool_file_meta, ) else: yield ToolInvokeMessage( type=ToolInvokeMessage.MessageType.LINK, message=ToolInvokeMessage.TextMessage(text=url), - meta=meta.copy() if meta is not 
None else {}, + meta=tool_file_meta, ) else: yield message @@ -162,9 +177,40 @@ class ToolFileMessageTransformer: if isinstance(message.message, ToolInvokeMessage.JsonMessage): message.message.json_object = safe_json_value(message.message.json_object) yield message + elif message.type in { + ToolInvokeMessage.MessageType.IMAGE_LINK, + ToolInvokeMessage.MessageType.BINARY_LINK, + } and isinstance(message.message, ToolInvokeMessage.TextMessage): + yield ToolInvokeMessage( + type=message.type, + message=message.message, + meta=cls._with_tool_file_meta(message.meta, url=message.message.text), + ) else: yield message @classmethod def get_tool_file_url(cls, tool_file_id: str, extension: str | None) -> str: return f"/files/tools/{tool_file_id}{extension or '.bin'}" + + @staticmethod + def _with_tool_file_meta( + meta: dict | None, + *, + tool_file_id: str | None = None, + url: str | None = None, + ) -> dict: + normalized_meta = meta.copy() if meta is not None else {} + resolved_tool_file_id = tool_file_id or ToolFileMessageTransformer._extract_tool_file_id(url) + if resolved_tool_file_id and "tool_file_id" not in normalized_meta: + normalized_meta["tool_file_id"] = resolved_tool_file_id + return normalized_meta + + @staticmethod + def _extract_tool_file_id(url: str | None) -> str | None: + if not url: + return None + match = _TOOL_FILE_URL_PATTERN.search(url) + if match is None: + return None + return match.group("tool_file_id") diff --git a/api/core/tools/utils/model_invocation_utils.py b/api/core/tools/utils/model_invocation_utils.py index 373bd1b1c8..8d6f83dc07 100644 --- a/api/core/tools/utils/model_invocation_utils.py +++ b/api/core/tools/utils/model_invocation_utils.py @@ -8,20 +8,21 @@ import json from decimal import Decimal from typing import cast -from core.model_manager import ModelManager -from core.tools.entities.tool_entities import ToolProviderType -from dify_graph.model_runtime.entities.llm_entities import LLMResult -from 
dify_graph.model_runtime.entities.message_entities import PromptMessage -from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey, ModelType -from dify_graph.model_runtime.errors.invoke import ( +from graphon.model_runtime.entities.llm_entities import LLMResult +from graphon.model_runtime.entities.message_entities import PromptMessage +from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType +from graphon.model_runtime.errors.invoke import ( InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, InvokeRateLimitError, InvokeServerUnavailableError, ) -from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from dify_graph.model_runtime.utils.encoders import jsonable_encoder +from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from graphon.model_runtime.utils.encoders import jsonable_encoder + +from core.model_manager import ModelManager +from core.tools.entities.tool_entities import ToolProviderType from extensions.ext_database import db from models.tools import ToolModelInvoke @@ -34,11 +35,12 @@ class ModelInvocationUtils: @staticmethod def get_max_llm_context_tokens( tenant_id: str, + user_id: str | None = None, ) -> int: """ get max llm context tokens of the model """ - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id, user_id=user_id) model_instance = model_manager.get_default_model_instance( tenant_id=tenant_id, model_type=ModelType.LLM, @@ -60,13 +62,13 @@ class ModelInvocationUtils: return max_tokens @staticmethod - def calculate_tokens(tenant_id: str, prompt_messages: list[PromptMessage]) -> int: + def calculate_tokens(tenant_id: str, prompt_messages: list[PromptMessage], user_id: str | None = None) -> int: """ calculate tokens from prompt messages and model parameters """ # get model instance - model_manager = ModelManager() + model_manager = 
ModelManager.for_tenant(tenant_id=tenant_id, user_id=user_id) model_instance = model_manager.get_default_model_instance(tenant_id=tenant_id, model_type=ModelType.LLM) if not model_instance: @@ -79,7 +81,12 @@ class ModelInvocationUtils: @staticmethod def invoke( - user_id: str, tenant_id: str, tool_type: ToolProviderType, tool_name: str, prompt_messages: list[PromptMessage] + user_id: str, + tenant_id: str, + tool_type: ToolProviderType, + tool_name: str, + prompt_messages: list[PromptMessage], + caller_user_id: str | None = None, ) -> LLMResult: """ invoke model with parameters in user's own context @@ -93,7 +100,7 @@ class ModelInvocationUtils: """ # get model manager - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id, user_id=caller_user_id or user_id) # get model instance model_instance = model_manager.get_default_model_instance( tenant_id=tenant_id, @@ -137,7 +144,6 @@ class ModelInvocationUtils: tools=[], stop=[], stream=False, - user=user_id, callbacks=[], ) except InvokeRateLimitError as e: diff --git a/api/core/tools/utils/workflow_configuration_sync.py b/api/core/tools/utils/workflow_configuration_sync.py index 28f1376655..c4b7d57449 100644 --- a/api/core/tools/utils/workflow_configuration_sync.py +++ b/api/core/tools/utils/workflow_configuration_sync.py @@ -1,11 +1,12 @@ from collections.abc import Mapping, Sequence from typing import Any +from graphon.enums import BuiltinNodeTypes +from graphon.nodes.base.entities import OutputVariableEntity +from graphon.variables.input_entities import VariableEntity + from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration from core.tools.errors import WorkflowToolHumanInputNotSupportedError -from dify_graph.enums import BuiltinNodeTypes -from dify_graph.nodes.base.entities import OutputVariableEntity -from dify_graph.variables.input_entities import VariableEntity class WorkflowToolConfigurationUtils: diff --git 
a/api/core/tools/workflow_as_tool/provider.py b/api/core/tools/workflow_as_tool/provider.py index aef8b3f779..f48b24be30 100644 --- a/api/core/tools/workflow_as_tool/provider.py +++ b/api/core/tools/workflow_as_tool/provider.py @@ -2,6 +2,7 @@ from __future__ import annotations from collections.abc import Mapping +from graphon.variables.input_entities import VariableEntity, VariableEntityType from pydantic import Field from sqlalchemy.orm import Session @@ -22,7 +23,6 @@ from core.tools.entities.tool_entities import ( ) from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils from core.tools.workflow_as_tool.tool import WorkflowTool -from dify_graph.variables.input_entities import VariableEntity, VariableEntityType from extensions.ext_database import db from models.account import Account from models.model import App, AppMode diff --git a/api/core/tools/workflow_as_tool/tool.py b/api/core/tools/workflow_as_tool/tool.py index 9b9aa7a741..a3fb4eda92 100644 --- a/api/core/tools/workflow_as_tool/tool.py +++ b/api/core/tools/workflow_as_tool/tool.py @@ -5,8 +5,11 @@ import logging from collections.abc import Generator, Mapping, Sequence from typing import Any, cast +from graphon.file import FILE_MODEL_IDENTITY, File, FileTransferMethod +from graphon.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata from sqlalchemy import select +from core.app.file_access import DatabaseFileAccessController from core.db.session_factory import session_factory from core.tools.__base.tool import Tool from core.tools.__base.tool_runtime import ToolRuntime @@ -17,14 +20,15 @@ from core.tools.entities.tool_entities import ( ToolProviderType, ) from core.tools.errors import ToolInvokeError -from dify_graph.file import FILE_MODEL_IDENTITY, File, FileTransferMethod -from dify_graph.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata +from core.workflow.file_reference import resolve_file_record_id from factories.file_factory 
import build_from_mapping from models import Account, Tenant from models.model import App, EndUser +from models.utils.file_input_compat import build_file_from_stored_mapping from models.workflow import Workflow logger = logging.getLogger(__name__) +_file_access_controller = DatabaseFileAccessController() class WorkflowTool(Tool): @@ -288,16 +292,25 @@ class WorkflowTool(Tool): file = tool_parameters.get(parameter.name) if file: try: - file_var_list = [File.model_validate(f) for f in file] + file_var_list = [ + build_file_from_stored_mapping( + file_mapping=cast(Mapping[str, Any], f), + tenant_id=str(self.runtime.tenant_id), + ) + for f in file + if isinstance(f, Mapping) + ] for file in file_var_list: file_dict: dict[str, str | None] = { "transfer_method": file.transfer_method.value, "type": file.type.value, } if file.transfer_method == FileTransferMethod.TOOL_FILE: - file_dict["tool_file_id"] = file.related_id + file_dict["tool_file_id"] = resolve_file_record_id(file.reference) elif file.transfer_method == FileTransferMethod.LOCAL_FILE: - file_dict["upload_file_id"] = file.related_id + file_dict["upload_file_id"] = resolve_file_record_id(file.reference) + elif file.transfer_method == FileTransferMethod.DATASOURCE_FILE: + file_dict["datasource_file_id"] = resolve_file_record_id(file.reference) elif file.transfer_method == FileTransferMethod.REMOTE_URL: file_dict["url"] = file.generate_url() @@ -325,6 +338,7 @@ class WorkflowTool(Tool): file = build_from_mapping( mapping=item, tenant_id=str(self.runtime.tenant_id), + access_controller=_file_access_controller, ) files.append(file) elif isinstance(value, dict) and value.get("dify_model_identity") == FILE_MODEL_IDENTITY: @@ -332,6 +346,7 @@ class WorkflowTool(Tool): file = build_from_mapping( mapping=value, tenant_id=str(self.runtime.tenant_id), + access_controller=_file_access_controller, ) files.append(file) @@ -340,9 +355,10 @@ class WorkflowTool(Tool): return result, files def _update_file_mapping(self, file_dict: 
dict): + file_id = resolve_file_record_id(file_dict.get("reference") or file_dict.get("related_id")) transfer_method = FileTransferMethod.value_of(file_dict.get("transfer_method")) if transfer_method == FileTransferMethod.TOOL_FILE: - file_dict["tool_file_id"] = file_dict.get("related_id") + file_dict["tool_file_id"] = file_id elif transfer_method == FileTransferMethod.LOCAL_FILE: - file_dict["upload_file_id"] = file_dict.get("related_id") + file_dict["upload_file_id"] = file_id return file_dict diff --git a/api/core/trigger/debug/event_selectors.py b/api/core/trigger/debug/event_selectors.py index 2a133b2b94..61d1cd8540 100644 --- a/api/core/trigger/debug/event_selectors.py +++ b/api/core/trigger/debug/event_selectors.py @@ -8,6 +8,7 @@ from collections.abc import Mapping from datetime import datetime from typing import Any +from graphon.entities.graph_config import NodeConfigDict from pydantic import BaseModel from core.plugin.entities.request import TriggerInvokeEventResponse @@ -26,7 +27,6 @@ from core.trigger.debug.events import ( ) from core.workflow.nodes.trigger_plugin.entities import TriggerEventNodeData from core.workflow.nodes.trigger_schedule.entities import ScheduleConfig -from dify_graph.entities.graph_config import NodeConfigDict from extensions.ext_redis import redis_client from libs.datetime_utils import ensure_naive_utc, naive_utc_now from libs.schedule_utils import calculate_next_run_at diff --git a/api/core/workflow/file_reference.py b/api/core/workflow/file_reference.py new file mode 100644 index 0000000000..c80acb3783 --- /dev/null +++ b/api/core/workflow/file_reference.py @@ -0,0 +1,52 @@ +from __future__ import annotations + +import base64 +import json +from dataclasses import dataclass + +_FILE_REFERENCE_PREFIX = "dify-file-ref:" + + +@dataclass(frozen=True) +class FileReference: + record_id: str + storage_key: str | None = None + + +def build_file_reference(*, record_id: str, storage_key: str | None = None) -> str: + payload = 
{"record_id": record_id} + if storage_key is not None: + payload["storage_key"] = storage_key + encoded_payload = base64.urlsafe_b64encode(json.dumps(payload, separators=(",", ":")).encode()).decode() + return f"{_FILE_REFERENCE_PREFIX}{encoded_payload}" + + +def parse_file_reference(reference: str | None) -> FileReference | None: + if not reference: + return None + + if not reference.startswith(_FILE_REFERENCE_PREFIX): + return FileReference(record_id=reference) + + encoded_payload = reference.removeprefix(_FILE_REFERENCE_PREFIX) + try: + payload = json.loads(base64.urlsafe_b64decode(encoded_payload.encode())) + except (ValueError, json.JSONDecodeError): + return FileReference(record_id=reference) + + record_id = payload.get("record_id") + if not isinstance(record_id, str) or not record_id: + return FileReference(record_id=reference) + + storage_key = payload.get("storage_key") + if storage_key is not None and not isinstance(storage_key, str): + storage_key = None + + return FileReference(record_id=record_id, storage_key=storage_key) + + +def resolve_file_record_id(reference: str | None) -> str | None: + parsed_reference = parse_file_reference(reference) + if parsed_reference is None: + return None + return parsed_reference.record_id diff --git a/api/core/workflow/human_input_compat.py b/api/core/workflow/human_input_compat.py new file mode 100644 index 0000000000..c95516a240 --- /dev/null +++ b/api/core/workflow/human_input_compat.py @@ -0,0 +1,298 @@ +"""Workflow-layer adapters for legacy human-input payload keys. + +Stored workflow graphs and editor payloads may still use Dify-specific human +input recipient keys. Normalize them here before handing configs to +`graphon` so graph-owned models only see graph-neutral field names. 
# Delivery channels for a human-input form. ``enum.auto()`` on a ``StrEnum``
# yields the lower-cased member name.
class DeliveryMethodType(enum.StrEnum):
    WEBAPP = enum.auto()
    EMAIL = enum.auto()


# Discriminator values for email recipients.
class EmailRecipientType(enum.StrEnum):
    BOUND = "member"
    # Same value as BOUND, so this is an alias rather than a distinct member.
    MEMBER = BOUND
    EXTERNAL = "external"


# Web-app channel configuration; it declares no fields.
class _InteractiveSurfaceDeliveryConfig(BaseModel):
    pass


# Recipient addressed through ``reference_id``; unknown keys are rejected.
class BoundRecipient(BaseModel):
    model_config = ConfigDict(extra="forbid")

    type: Literal[EmailRecipientType.BOUND] = EmailRecipientType.BOUND
    reference_id: str


# Recipient addressed through a literal email address; unknown keys are rejected.
class ExternalRecipient(BaseModel):
    model_config = ConfigDict(extra="forbid")

    type: Literal[EmailRecipientType.EXTERNAL] = EmailRecipientType.EXTERNAL
    email: str


MemberRecipient = BoundRecipient
# Tagged union: the ``type`` field selects the concrete recipient model.
EmailRecipient = Annotated[BoundRecipient | ExternalRecipient, Field(discriminator="type")]


# Recipient list of an email delivery. ``include_bound_group`` also validates
# from the legacy key ``whole_workspace``.
class EmailRecipients(BaseModel):
    model_config = ConfigDict(extra="forbid")

    include_bound_group: bool = Field(
        default=False,
        validation_alias=AliasChoices("include_bound_group", "whole_workspace"),
    )
    items: list[EmailRecipient] = Field(default_factory=list)


# Email channel configuration plus the helpers that render and sanitize the
# outgoing subject and body.
class EmailDeliveryConfig(BaseModel):
    # Token in the body template that is substituted with the form URL.
    URL_PLACEHOLDER: ClassVar[str] = "{{#url#}}"
    # Allow-list applied to the HTML produced from the Markdown body.
    _ALLOWED_HTML_TAGS: ClassVar[list[str]] = [
        "a", "br", "code", "em",
        "li", "ol", "p", "pre",
        "strong", "table", "tbody", "td",
        "th", "thead", "tr", "ul",
    ]
    _ALLOWED_HTML_ATTRIBUTES: ClassVar[dict[str, list[str]]] = {
        "a": ["href", "title"],
        "td": ["align"],
        "th": ["align"],
    }
    # bleach's default protocol set, extended with ``mailto``.
    _ALLOWED_PROTOCOLS: ClassVar[set[str]] = set(bleach.sanitizer.ALLOWED_PROTOCOLS) | {"mailto"}

    recipients: EmailRecipients
    subject: str
    body: str
    debug_mode: bool = False

    def with_recipients(self, recipients: EmailRecipients) -> EmailDeliveryConfig:
        """Return a copy of this config whose ``recipients`` is replaced."""
        overrides = {"recipients": recipients}
        return self.model_copy(update=overrides)

    @classmethod
    def replace_url_placeholder(cls, body: str, url: str | None) -> str:
        """Substitute every ``URL_PLACEHOLDER`` in ``body``; a falsy ``url`` becomes ``""``."""
        substitution = url if url else ""
        return body.replace(cls.URL_PLACEHOLDER, substitution)

    @classmethod
    def render_body_template(
        cls,
        *,
        body: str,
        url: str | None,
        variable_pool: VariablePool | None = None,
    ) -> str:
        """Fill in the URL placeholder, then expand pool variables when a pool is given."""
        with_url = cls.replace_url_placeholder(body, url)
        if variable_pool is not None:
            return variable_pool.convert_template(with_url).text
        return with_url

    @classmethod
    def render_markdown_body(cls, body: str) -> str:
        """Render ``body`` as Markdown and return allow-listed HTML."""
        # First pass: no tag is allowed in the Markdown source itself.
        markdown_source = bleach.clean(body, tags=[], attributes={}, strip=True)
        table_extension = TableExtension(use_align_attribute=True)
        html = markdown.markdown(
            markdown_source,
            extensions=[table_extension],
            output_format="html",
        )
        # Second pass: restrict the rendered HTML to the class allow-lists.
        return bleach.clean(
            html,
            tags=cls._ALLOWED_HTML_TAGS,
            attributes=cls._ALLOWED_HTML_ATTRIBUTES,
            protocols=cls._ALLOWED_PROTOCOLS,
            strip=True,
        )

    @staticmethod
    def sanitize_subject(subject: str) -> str:
        """Return ``subject`` on one line, cleaned by bleach, with whitespace runs collapsed."""
        single_line = subject.replace("\r", " ").replace("\n", " ")
        cleaned = bleach.clean(single_line, tags=[], strip=True)
        return " ".join(cleaned.split())


# Fields shared by every delivery method; the base reports no variable selectors.
class _DeliveryMethodBase(BaseModel):
    enabled: bool = True
    id: uuid.UUID = Field(default_factory=uuid.uuid4)

    def extract_variable_selectors(self) -> Sequence[Sequence[str]]:
        return ()


# Web-app delivery method.
class InteractiveSurfaceDeliveryMethod(_DeliveryMethodBase):
    type: Literal[DeliveryMethodType.WEBAPP] = DeliveryMethodType.WEBAPP
    config: _InteractiveSurfaceDeliveryConfig = Field(default_factory=_InteractiveSurfaceDeliveryConfig)
# Email delivery method; variable selectors are read from the body template.
class EmailDeliveryMethod(_DeliveryMethodBase):
    type: Literal[DeliveryMethodType.EMAIL] = DeliveryMethodType.EMAIL
    config: EmailDeliveryConfig

    def extract_variable_selectors(self) -> Sequence[Sequence[str]]:
        """Return each body-template selector truncated to ``SELECTORS_LENGTH`` parts.

        Selectors with fewer than ``SELECTORS_LENGTH`` parts are dropped.
        """
        parser = VariableTemplateParser(template=self.config.body)
        candidates = (list(found.value_selector) for found in parser.extract_variable_selectors())
        return [parts[:SELECTORS_LENGTH] for parts in candidates if len(parts) >= SELECTORS_LENGTH]


WebAppDeliveryMethod = InteractiveSurfaceDeliveryMethod
_WebAppDeliveryConfig = _InteractiveSurfaceDeliveryConfig

DeliveryChannelConfig = Annotated[InteractiveSurfaceDeliveryMethod | EmailDeliveryMethod, Field(discriminator="type")]

_DELIVERY_METHODS_ADAPTER = TypeAdapter(list[DeliveryChannelConfig])


def _copy_mapping(value: object) -> dict[str, Any] | None:
    """Return a fresh ``dict`` for a pydantic model or mapping, else ``None``."""
    if isinstance(value, BaseModel):
        return value.model_dump(mode="python")
    return dict(value) if isinstance(value, Mapping) else None


def _normalize_delivery_method(method: object) -> object:
    """Copy one delivery method and normalize its ``config.recipients``.

    Values that are neither a mapping nor a model are returned untouched.
    """
    method_copy = _copy_mapping(method)
    if method_copy is None:
        return method

    config_copy = _copy_mapping(method_copy.get("config"))
    if config_copy is not None:
        recipients_copy = _copy_mapping(config_copy.get("recipients"))
        if recipients_copy is not None:
            config_copy["recipients"] = _normalize_email_recipients(recipients_copy)
        method_copy["config"] = config_copy
    return method_copy


def normalize_human_input_node_data_for_graph(node_data: Mapping[str, Any] | BaseModel) -> dict[str, Any]:
    """Return a copy of human-input node data with legacy recipient keys renamed.

    Only a ``delivery_methods`` value that is a ``list`` is rewritten; any
    other value is left as found.

    Raises:
        TypeError: ``node_data`` is neither a mapping nor a pydantic model.
    """
    result = _copy_mapping(node_data)
    if result is None:
        raise TypeError(f"human-input node data must be a mapping, got {type(node_data).__name__}")

    methods = result.get("delivery_methods")
    if isinstance(methods, list):
        result["delivery_methods"] = [_normalize_delivery_method(entry) for entry in methods]
    return result


def parse_human_input_delivery_methods(node_data: Mapping[str, Any] | BaseModel) -> list[DeliveryChannelConfig]:
    """Normalize ``node_data`` and validate its delivery methods into typed models.

    Returns an empty list when ``delivery_methods`` is missing or not a list.
    """
    raw_methods = normalize_human_input_node_data_for_graph(node_data).get("delivery_methods")
    if isinstance(raw_methods, list):
        return list(_DELIVERY_METHODS_ADAPTER.validate_python(raw_methods))
    return []


def is_human_input_webapp_enabled(node_data: Mapping[str, Any] | BaseModel) -> bool:
    """Report whether any enabled delivery method is of the web-app type."""
    return any(
        candidate.enabled and candidate.type == DeliveryMethodType.WEBAPP
        for candidate in parse_human_input_delivery_methods(node_data)
    )


def normalize_node_data_for_graph(node_data: Mapping[str, Any] | BaseModel) -> dict[str, Any]:
    """Copy node data, applying human-input normalization only to human-input nodes.

    Raises:
        TypeError: ``node_data`` is neither a mapping nor a pydantic model.
    """
    result = _copy_mapping(node_data)
    if result is None:
        raise TypeError(f"node data must be a mapping, got {type(node_data).__name__}")

    if result.get("type") == BuiltinNodeTypes.HUMAN_INPUT:
        return normalize_human_input_node_data_for_graph(result)
    return result


def normalize_node_config_for_graph(node_config: Mapping[str, Any] | BaseModel) -> dict[str, Any]:
    """Copy a node config and normalize its ``data`` entry when that entry is a mapping or model.

    Raises:
        TypeError: ``node_config`` is neither a mapping nor a pydantic model.
    """
    result = _copy_mapping(node_config)
    if result is None:
        raise TypeError(f"node config must be a mapping, got {type(node_config).__name__}")

    data_copy = _copy_mapping(result.get("data"))
    if data_copy is not None:
        result["data"] = normalize_node_data_for_graph(data_copy)
    return result


def _normalize_email_recipients(recipients: Mapping[str, Any]) -> dict[str, Any]:
    """Rename legacy recipient keys on a copy of ``recipients``.

    ``whole_workspace`` becomes ``include_bound_group`` and each item's
    ``user_id`` becomes ``reference_id``. The legacy key is always removed and
    never overrides a new-style key that is already present.
    """
    result = dict(recipients)

    legacy_flag = result.pop("whole_workspace", None)
    if legacy_flag is not None:
        # setdefault only writes when the new-style key is absent.
        result.setdefault("include_bound_group", legacy_flag)

    raw_items = result.get("items")
    if not isinstance(raw_items, list):
        return result

    renamed_items: list[Any] = []
    for raw_item in raw_items:
        item_copy = _copy_mapping(raw_item)
        if item_copy is None:
            # Not a mapping or model: pass through unchanged.
            renamed_items.append(raw_item)
            continue

        legacy_id = item_copy.pop("user_id", None)
        if legacy_id is not None:
            item_copy.setdefault("reference_id", legacy_id)
        renamed_items.append(item_copy)

    result["items"] = renamed_items
    return result


__all__ = [
    "BoundRecipient",
    "DeliveryChannelConfig",
    "DeliveryMethodType",
    "EmailDeliveryConfig",
    "EmailDeliveryMethod",
    "EmailRecipientType",
    "EmailRecipients",
    "ExternalRecipient",
    "MemberRecipient",
    "WebAppDeliveryMethod",
    "_WebAppDeliveryConfig",
    "is_human_input_webapp_enabled",
    "normalize_human_input_node_data_for_graph",
    "normalize_node_config_for_graph",
    "normalize_node_data_for_graph",
    "parse_human_input_delivery_methods",
]
# Rank of each recipient type when several recipients exist for one form;
# the lowest number wins. Types missing from this table are never selected.
_FORM_TOKEN_PRIORITY = {
    RecipientType.BACKSTAGE: 0,
    RecipientType.CONSOLE: 1,
    RecipientType.STANDALONE_WEB_APP: 2,
}


def load_form_tokens_by_form_id(
    form_ids: Sequence[str],
    *,
    session: Session | None = None,
) -> dict[str, str]:
    """Return the preferred access token for each human input form.

    Args:
        form_ids: Form ids to look up; duplicates are collapsed, keeping first-seen order.
        session: Session to query with. When omitted, a session bound to
            ``db.engine`` is opened for the duration of the lookup.

    Returns:
        Mapping of form id to access token. Forms without a usable recipient are absent.
    """
    deduplicated_ids = list(dict.fromkeys(form_ids))
    if not deduplicated_ids:
        return {}

    if session is None:
        with Session(bind=db.engine, expire_on_commit=False) as owned_session:
            return _load_form_tokens_by_form_id(owned_session, deduplicated_ids)

    return _load_form_tokens_by_form_id(session, deduplicated_ids)


def _load_form_tokens_by_form_id(session: Session, form_ids: Sequence[str]) -> dict[str, str]:
    """Query recipients for ``form_ids`` and keep the best-ranked token per form.

    Recipients whose type is not ranked, or whose token is empty, are ignored.
    On equal rank the recipient seen first is kept.
    """
    best_by_form: dict[str, tuple[int, str]] = {}
    query = select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id.in_(form_ids))

    for recipient in session.scalars(query):
        rank = _FORM_TOKEN_PRIORITY.get(recipient.recipient_type)
        if rank is None or not recipient.access_token:
            continue

        known = best_by_form.get(recipient.form_id)
        if known is not None and known[0] <= rank:
            # An equal or better-ranked token is already recorded.
            continue
        best_by_form[recipient.form_id] = (rank, recipient.access_token)

    return {form_id: token for form_id, (_, token) in best_by_form.items()}
graphon.entities.base_node_data import BaseNodeData +from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter +from graphon.enums import BuiltinNodeTypes, NodeType +from graphon.file.file_manager import file_manager +from graphon.graph.graph import NodeFactory +from graphon.model_runtime.memory import PromptMessageMemory +from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from graphon.nodes.base.node import Node +from graphon.nodes.code.code_node import WorkflowCodeExecutor +from graphon.nodes.code.entities import CodeLanguage +from graphon.nodes.code.limits import CodeNodeLimits +from graphon.nodes.document_extractor import UnstructuredApiConfig +from graphon.nodes.http_request import build_http_request_config +from graphon.nodes.llm.entities import LLMNodeData +from graphon.nodes.parameter_extractor.entities import ParameterExtractorNodeData +from graphon.nodes.question_classifier.entities import QuestionClassifierNodeData from sqlalchemy import select from sqlalchemy.orm import Session from typing_extensions import override from configs import dify_config -from core.app.entities.app_invoke_entities import DifyRunContext -from core.app.llm.model_access import build_dify_model_access +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext +from core.app.llm.model_access import build_dify_model_access, fetch_model_config from core.helper.code_executor.code_executor import ( CodeExecutionError, CodeExecutor, @@ -19,45 +35,32 @@ from core.helper.ssrf_proxy import ssrf_proxy from core.memory.token_buffer_memory import TokenBufferMemory from core.model_manager import ModelInstance from core.prompt.entities.advanced_prompt_entities import MemoryConfig -from core.repositories.human_input_repository import HumanInputFormRepositoryImpl -from core.tools.tool_file_manager import ToolFileManager from core.trigger.constants import TRIGGER_NODE_TYPES +from 
core.workflow.human_input_compat import normalize_node_config_for_graph +from core.workflow.node_runtime import ( + DifyFileReferenceFactory, + DifyHumanInputNodeRuntime, + DifyPreparedLLM, + DifyPromptMessageSerializer, + DifyRetrieverAttachmentLoader, + DifyToolFileManager, + DifyToolNodeRuntime, + build_dify_llm_file_saver, +) from core.workflow.nodes.agent.message_transformer import AgentMessageTransformer from core.workflow.nodes.agent.plugin_strategy_adapter import ( PluginAgentStrategyPresentationProvider, PluginAgentStrategyResolver, ) from core.workflow.nodes.agent.runtime_support import AgentRuntimeSupport -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY -from dify_graph.enums import BuiltinNodeTypes, NodeType, SystemVariableKey -from dify_graph.file.file_manager import file_manager -from dify_graph.graph.graph import NodeFactory -from dify_graph.model_runtime.entities.model_entities import ModelType -from dify_graph.model_runtime.memory import PromptMessageMemory -from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.code.code_node import WorkflowCodeExecutor -from dify_graph.nodes.code.entities import CodeLanguage -from dify_graph.nodes.code.limits import CodeNodeLimits -from dify_graph.nodes.document_extractor import UnstructuredApiConfig -from dify_graph.nodes.http_request import build_http_request_config -from dify_graph.nodes.llm.entities import LLMNodeData -from dify_graph.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError -from dify_graph.nodes.llm.protocols import TemplateRenderer -from dify_graph.nodes.parameter_extractor.entities import ParameterExtractorNodeData -from dify_graph.nodes.question_classifier.entities import QuestionClassifierNodeData -from 
dify_graph.nodes.template_transform.template_renderer import ( - CodeExecutorJinja2TemplateRenderer, -) -from dify_graph.variables.segments import StringSegment +from core.workflow.system_variables import SystemVariableKey, get_system_text, system_variable_selector +from core.workflow.template_rendering import CodeExecutorJinja2TemplateRenderer from extensions.ext_database import db from models.model import Conversation if TYPE_CHECKING: - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState + from graphon.entities import GraphInitParams + from graphon.runtime import GraphRuntimeState LATEST_VERSION = "latest" _START_NODE_TYPES: frozenset[NodeType] = frozenset( @@ -76,7 +79,7 @@ def _import_node_package(package_name: str, *, excluded_modules: frozenset[str] @lru_cache(maxsize=1) def register_nodes() -> None: """Import production node modules so they self-register with ``Node``.""" - _import_node_package("dify_graph.nodes") + _import_node_package("graphon.nodes") _import_node_package("core.workflow.nodes") @@ -84,7 +87,7 @@ def get_node_type_classes_mapping() -> Mapping[NodeType, Mapping[str, type[Node] """Return a read-only snapshot of the current production node registry. The workflow layer owns node bootstrap because it must compose built-in - `dify_graph.nodes.*` implementations with workflow-local nodes under + `graphon.nodes.*` implementations with workflow-local nodes under `core.workflow.nodes.*`. Keeping this import side effect here avoids reintroducing registry bootstrapping into lower-level graph primitives. """ @@ -115,7 +118,7 @@ def get_default_root_node_id(graph_config: Mapping[str, Any]) -> str: This workflow-layer helper depends on start-node semantics defined by `is_start_node_type`, so it intentionally lives next to the node registry - instead of in the raw `dify_graph.entities.graph_config` schema module. + instead of in the raw `graphon.entities.graph_config` schema module. 
""" nodes = graph_config.get("nodes") if not isinstance(nodes, list): @@ -229,16 +232,6 @@ class DefaultWorkflowCodeExecutor: return isinstance(error, CodeExecutionError) -class DefaultLLMTemplateRenderer(TemplateRenderer): - def render_jinja2(self, *, template: str, inputs: Mapping[str, Any]) -> str: - result = CodeExecutor.execute_workflow_code_template( - language=CodeLanguage.JINJA2, - code=template, - inputs=inputs, - ) - return str(result.get("result", "")) - - @final class DifyNodeFactory(NodeFactory): """ @@ -264,11 +257,31 @@ class DifyNodeFactory(NodeFactory): max_string_array_length=dify_config.CODE_MAX_STRING_ARRAY_LENGTH, max_object_array_length=dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH, ) - self._template_renderer = CodeExecutorJinja2TemplateRenderer(code_executor=self._code_executor) - self._llm_template_renderer: TemplateRenderer = DefaultLLMTemplateRenderer() + self._jinja2_template_renderer = CodeExecutorJinja2TemplateRenderer() self._template_transform_max_output_length = dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH self._http_request_http_client = ssrf_proxy - self._http_request_tool_file_manager_factory = ToolFileManager + self._bound_tool_file_manager_factory = lambda: DifyToolFileManager( + self._dify_context, + conversation_id_getter=self._conversation_id, + ) + self._file_reference_factory = DifyFileReferenceFactory(self._dify_context) + self._prompt_message_serializer = DifyPromptMessageSerializer() + self._retriever_attachment_loader = DifyRetrieverAttachmentLoader( + file_reference_factory=self._file_reference_factory, + ) + self._llm_file_saver = build_dify_llm_file_saver( + run_context=self._dify_context, + http_client=self._http_request_http_client, + conversation_id_getter=self._conversation_id, + ) + self._human_input_runtime = DifyHumanInputNodeRuntime( + self._dify_context, + workflow_execution_id_getter=lambda: get_system_text( + self.graph_runtime_state.variable_pool, + SystemVariableKey.WORKFLOW_EXECUTION_ID, + ), + ) + 
self._tool_runtime = DifyToolNodeRuntime(self._dify_context) self._http_request_file_manager = file_manager self._document_extractor_unstructured_api_config = UnstructuredApiConfig( api_url=dify_config.UNSTRUCTURED_API_URL, @@ -284,7 +297,7 @@ class DifyNodeFactory(NodeFactory): ssrf_default_max_retries=dify_config.SSRF_DEFAULT_MAX_RETRIES, ) - self._llm_credentials_provider, self._llm_model_factory = build_dify_model_access(self._dify_context.tenant_id) + self._llm_credentials_provider, self._llm_model_factory = build_dify_model_access(self._dify_context) self._agent_strategy_resolver = PluginAgentStrategyResolver() self._agent_strategy_presentation_provider = PluginAgentStrategyPresentationProvider() self._agent_runtime_support = AgentRuntimeSupport() @@ -299,6 +312,9 @@ class DifyNodeFactory(NodeFactory): return raw_ctx return DifyRunContext.model_validate(raw_ctx) + def _conversation_id(self) -> str | None: + return get_system_text(self.graph_runtime_state.variable_pool, SystemVariableKey.CONVERSATION_ID) + @override def create_node(self, node_config: dict[str, Any] | NodeConfigDict) -> Node: """ @@ -310,7 +326,7 @@ class DifyNodeFactory(NodeFactory): (including pydantic ValidationError, which subclasses ValueError), if node type is unknown, or if no implementation exists for the resolved version """ - typed_node_config = NodeConfigDictAdapter.validate_python(node_config) + typed_node_config = NodeConfigDictAdapter.validate_python(normalize_node_config_for_graph(node_config)) node_id = typed_node_config["id"] node_data = typed_node_config["data"] node_class = self._resolve_node_class(node_type=node_data.type, node_version=str(node_data.version)) @@ -321,22 +337,29 @@ class DifyNodeFactory(NodeFactory): "code_limits": self._code_limits, }, BuiltinNodeTypes.TEMPLATE_TRANSFORM: lambda: { - "template_renderer": self._template_renderer, + "jinja2_template_renderer": self._jinja2_template_renderer, "max_output_length": self._template_transform_max_output_length, }, 
BuiltinNodeTypes.HTTP_REQUEST: lambda: { "http_request_config": self._http_request_config, "http_client": self._http_request_http_client, - "tool_file_manager_factory": self._http_request_tool_file_manager_factory, + "tool_file_manager_factory": self._bound_tool_file_manager_factory, "file_manager": self._http_request_file_manager, + "file_reference_factory": self._file_reference_factory, }, BuiltinNodeTypes.HUMAN_INPUT: lambda: { - "form_repository": HumanInputFormRepositoryImpl(tenant_id=self._dify_context.tenant_id), + "runtime": self._human_input_runtime, + "form_repository": self._human_input_runtime.build_form_repository(), }, BuiltinNodeTypes.LLM: lambda: self._build_llm_compatible_node_init_kwargs( node_class=node_class, node_data=node_data, + wrap_model_instance=True, include_http_client=True, + include_llm_file_saver=True, + include_prompt_message_serializer=True, + include_retriever_attachment_loader=True, + include_jinja2_template_renderer=True, ), BuiltinNodeTypes.DOCUMENT_EXTRACTOR: lambda: { "unstructured_api_config": self._document_extractor_unstructured_api_config, @@ -345,15 +368,26 @@ class DifyNodeFactory(NodeFactory): BuiltinNodeTypes.QUESTION_CLASSIFIER: lambda: self._build_llm_compatible_node_init_kwargs( node_class=node_class, node_data=node_data, + wrap_model_instance=True, include_http_client=True, + include_llm_file_saver=True, + include_prompt_message_serializer=True, + include_retriever_attachment_loader=False, + include_jinja2_template_renderer=False, ), BuiltinNodeTypes.PARAMETER_EXTRACTOR: lambda: self._build_llm_compatible_node_init_kwargs( node_class=node_class, node_data=node_data, + wrap_model_instance=True, include_http_client=False, + include_llm_file_saver=False, + include_prompt_message_serializer=True, + include_retriever_attachment_loader=False, + include_jinja2_template_renderer=False, ), BuiltinNodeTypes.TOOL: lambda: { - "tool_file_manager_factory": self._http_request_tool_file_manager_factory(), + 
"tool_file_manager_factory": self._bound_tool_file_manager_factory(), + "runtime": self._tool_runtime, }, BuiltinNodeTypes.AGENT: lambda: { "strategy_resolver": self._agent_strategy_resolver, @@ -387,7 +421,12 @@ class DifyNodeFactory(NodeFactory): *, node_class: type[Node], node_data: BaseNodeData, + wrap_model_instance: bool, include_http_client: bool, + include_llm_file_saver: bool, + include_prompt_message_serializer: bool, + include_retriever_attachment_loader: bool, + include_jinja2_template_renderer: bool, ) -> dict[str, object]: validated_node_data = cast( LLMCompatibleNodeData, @@ -397,49 +436,35 @@ class DifyNodeFactory(NodeFactory): node_init_kwargs: dict[str, object] = { "credentials_provider": self._llm_credentials_provider, "model_factory": self._llm_model_factory, - "model_instance": model_instance, + "model_instance": DifyPreparedLLM(model_instance) if wrap_model_instance else model_instance, "memory": self._build_memory_for_llm_node( node_data=validated_node_data, model_instance=model_instance, ), } - if validated_node_data.type in {BuiltinNodeTypes.LLM, BuiltinNodeTypes.QUESTION_CLASSIFIER}: - node_init_kwargs["template_renderer"] = self._llm_template_renderer + if validated_node_data.type == BuiltinNodeTypes.QUESTION_CLASSIFIER: + node_init_kwargs["template_renderer"] = self._jinja2_template_renderer if include_http_client: node_init_kwargs["http_client"] = self._http_request_http_client + if include_llm_file_saver: + node_init_kwargs["llm_file_saver"] = self._llm_file_saver + if include_prompt_message_serializer: + node_init_kwargs["prompt_message_serializer"] = self._prompt_message_serializer + if include_retriever_attachment_loader: + node_init_kwargs["retriever_attachment_loader"] = self._retriever_attachment_loader + if include_jinja2_template_renderer: + node_init_kwargs["jinja2_template_renderer"] = self._jinja2_template_renderer + if validated_node_data.type == BuiltinNodeTypes.LLM: + node_init_kwargs["default_query_selector"] = 
system_variable_selector(SystemVariableKey.QUERY) return node_init_kwargs def _build_model_instance_for_llm_node(self, node_data: LLMCompatibleNodeData) -> ModelInstance: node_data_model = node_data.model - if not node_data_model.mode: - raise LLMModeRequiredError("LLM mode is required.") - - credentials = self._llm_credentials_provider.fetch(node_data_model.provider, node_data_model.name) - model_instance = self._llm_model_factory.init_model_instance(node_data_model.provider, node_data_model.name) - provider_model_bundle = model_instance.provider_model_bundle - - provider_model = provider_model_bundle.configuration.get_provider_model( - model=node_data_model.name, - model_type=ModelType.LLM, + model_instance, _ = fetch_model_config( + node_data_model=node_data_model, + credentials_provider=self._llm_credentials_provider, + model_factory=self._llm_model_factory, ) - if provider_model is None: - raise ModelNotExistError(f"Model {node_data_model.name} not exist.") - provider_model.raise_for_status() - - completion_params = dict(node_data_model.completion_params) - stop = completion_params.pop("stop", []) - if not isinstance(stop, list): - stop = [] - - model_schema = model_instance.model_type_instance.get_model_schema(node_data_model.name, credentials) - if not model_schema: - raise ModelNotExistError(f"Model {node_data_model.name} not exist.") - - model_instance.provider = node_data_model.provider - model_instance.model_name = node_data_model.name - model_instance.credentials = credentials - model_instance.parameters = completion_params - model_instance.stop = tuple(stop) model_instance.model_type_instance = cast(LargeLanguageModel, model_instance.model_type_instance) return model_instance @@ -452,12 +477,7 @@ class DifyNodeFactory(NodeFactory): if node_data.memory is None: return None - conversation_id_variable = self.graph_runtime_state.variable_pool.get( - ["sys", SystemVariableKey.CONVERSATION_ID] - ) - conversation_id = ( - conversation_id_variable.value if 
isinstance(conversation_id_variable, StringSegment) else None - ) + conversation_id = get_system_text(self.graph_runtime_state.variable_pool, SystemVariableKey.CONVERSATION_ID) return fetch_memory( conversation_id=conversation_id, app_id=self._dify_context.app_id, diff --git a/api/core/workflow/node_runtime.py b/api/core/workflow/node_runtime.py new file mode 100644 index 0000000000..19cb3a7b0a --- /dev/null +++ b/api/core/workflow/node_runtime.py @@ -0,0 +1,671 @@ +from __future__ import annotations + +from collections.abc import Callable, Generator, Mapping, Sequence +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, cast + +from graphon.file import FileTransferMethod, FileType +from graphon.model_runtime.entities import LLMMode +from graphon.model_runtime.entities.llm_entities import ( + LLMResult, + LLMResultChunk, + LLMResultChunkWithStructuredOutput, + LLMResultWithStructuredOutput, + LLMUsage, +) +from graphon.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool +from graphon.model_runtime.entities.model_entities import AIModelEntity +from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from graphon.nodes.human_input.entities import HumanInputNodeData +from graphon.nodes.llm.runtime_protocols import ( + PreparedLLMProtocol, + PromptMessageSerializerProtocol, + RetrieverAttachmentLoaderProtocol, +) +from graphon.nodes.protocols import FileReferenceFactoryProtocol, HttpClientProtocol, ToolFileManagerProtocol +from graphon.nodes.runtime import ( + HumanInputFormStateProtocol, + HumanInputNodeRuntimeProtocol, + ToolNodeRuntimeProtocol, +) +from graphon.nodes.tool.exc import ToolNodeError, ToolRuntimeInvocationError, ToolRuntimeResolutionError +from graphon.nodes.tool_runtime_entities import ( + ToolRuntimeHandle, + ToolRuntimeMessage, + ToolRuntimeParameter, +) +from sqlalchemy import select +from sqlalchemy.orm import Session + +from 
core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext +from core.app.file_access import DatabaseFileAccessController +from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler +from core.llm_generator.output_parser.errors import OutputParserError +from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output +from core.model_manager import ModelInstance +from core.plugin.impl.exc import PluginDaemonClientSideError, PluginInvokeError +from core.plugin.impl.plugin import PluginInstaller +from core.prompt.utils.prompt_message_util import PromptMessageUtil +from core.repositories.human_input_repository import ( + FormCreateParams, + HumanInputFormRepository, + HumanInputFormRepositoryImpl, +) +from core.tools.entities.tool_entities import ToolProviderType as CoreToolProviderType +from core.tools.errors import ToolInvokeError +from core.tools.tool_engine import ToolEngine +from core.tools.tool_file_manager import ToolFileManager +from core.tools.tool_manager import ToolManager +from core.tools.utils.message_transformer import ToolFileMessageTransformer +from core.workflow.file_reference import build_file_reference +from extensions.ext_database import db +from factories import file_factory +from models.dataset import SegmentAttachmentBinding +from models.model import UploadFile +from services.tools.builtin_tools_manage_service import BuiltinToolManageService + +from .human_input_compat import ( + BoundRecipient, + DeliveryChannelConfig, + DeliveryMethodType, + EmailDeliveryMethod, + EmailRecipients, + is_human_input_webapp_enabled, + parse_human_input_delivery_methods, +) +from .system_variables import SystemVariableKey, get_system_text + +if TYPE_CHECKING: + from graphon.file import File + from graphon.nodes.llm.file_saver import LLMFileSaver + from graphon.nodes.tool.entities import ToolNodeData + + from core.tools.__base.tool import Tool + from 
_file_access_controller = DatabaseFileAccessController()


def resolve_dify_run_context(run_context: Mapping[str, Any] | DifyRunContext) -> DifyRunContext:
    """Return the ``DifyRunContext`` carried by ``run_context``.

    ``run_context`` is either the context itself or a mapping holding it (as an
    instance or as raw data to validate) under ``DIFY_RUN_CONTEXT_KEY``.

    Raises:
        ValueError: the mapping has no value under ``DIFY_RUN_CONTEXT_KEY``.
    """
    if isinstance(run_context, DifyRunContext):
        candidate = run_context
    else:
        candidate = run_context.get(DIFY_RUN_CONTEXT_KEY)

    if candidate is None:
        raise ValueError(f"run_context missing required key: {DIFY_RUN_CONTEXT_KEY}")
    return candidate if isinstance(candidate, DifyRunContext) else DifyRunContext.model_validate(candidate)


def apply_dify_debug_email_recipient(
    method: DeliveryChannelConfig,
    *,
    enabled: bool,
    actor_id: str | None,
) -> DeliveryChannelConfig:
    """Redirect a debug-mode email delivery to the acting user.

    ``method`` is returned unchanged unless ``enabled`` is true, it is an
    ``EmailDeliveryMethod``, and its config has ``debug_mode`` set. Otherwise a
    copy is returned whose recipients are only ``actor_id`` as a bound
    recipient, or nobody when ``actor_id`` is ``None``.
    """
    applies = enabled and isinstance(method, EmailDeliveryMethod) and method.config.debug_mode
    if not applies:
        return method

    items = [] if actor_id is None else [BoundRecipient(reference_id=actor_id)]
    override = EmailRecipients(include_bound_group=False, items=items)
    return method.model_copy(update={"config": method.config.with_recipients(override)})


class DifyFileReferenceFactory(FileReferenceFactoryProtocol):
    """Builds file objects from mappings, scoped to the run context's tenant."""

    def __init__(self, run_context: Mapping[str, Any] | DifyRunContext) -> None:
        self._run_context = resolve_dify_run_context(run_context)

    def build_from_mapping(self, *, mapping: Mapping[str, Any]):
        """Delegate to ``file_factory.build_from_mapping`` with the tenant id and shared access controller."""
        tenant_id = self._run_context.tenant_id
        return file_factory.build_from_mapping(
            mapping=mapping,
            tenant_id=tenant_id,
            access_controller=_file_access_controller,
        )


class DifyPreparedLLM(PreparedLLMProtocol):
    """Workflow-layer adapter exposing a narrow LLM surface of `ModelInstance` to `graphon` nodes."""

    def __init__(self, model_instance: ModelInstance) -> None:
        self._model_instance = model_instance

    @property
    def provider(self) -> str:
        return self._model_instance.provider

    @property
    def model_name(self) -> str:
        return self._model_instance.model_name

    @property
    def parameters(self) -> Mapping[str, Any]:
        return self._model_instance.parameters

    @parameters.setter
    def parameters(self, value: Mapping[str, Any]) -> None:
        # Writes through to the wrapped instance.
        self._model_instance.parameters = value

    @property
    def stop(self) -> Sequence[str] | None:
        return self._model_instance.stop

    def get_model_schema(self) -> AIModelEntity:
        """Return the wrapped model's schema.

        Raises:
            ValueError: the model type instance returns no schema.
        """
        instance = self._model_instance
        llm = cast(LargeLanguageModel, instance.model_type_instance)
        schema = llm.get_model_schema(instance.model_name, instance.credentials)
        if schema is None:
            raise ValueError(f"Model schema not found for {self._model_instance.model_name}")
        return schema

    def get_llm_num_tokens(self, prompt_messages: Sequence[PromptMessage]) -> int:
        return self._model_instance.get_llm_num_tokens(prompt_messages)

    def invoke_llm(
        self,
        *,
        prompt_messages: Sequence[PromptMessage],
        model_parameters: Mapping[str, Any],
        tools: Sequence[PromptMessageTool] | None,
        stop: Sequence[str] | None,
        stream: bool,
    ) -> LLMResult | Generator[LLMResultChunk, None, None]:
        """Invoke the wrapped model with concrete ``list``/``dict`` copies of the arguments.

        A ``None`` (or empty) ``tools`` or ``stop`` is forwarded as an empty list.
        """
        message_list = list(prompt_messages)
        parameter_dict = dict(model_parameters)
        tool_list = list(tools) if tools else []
        stop_list = list(stop) if stop else []
        return self._model_instance.invoke_llm(
            prompt_messages=message_list,
            model_parameters=parameter_dict,
            tools=tool_list,
            stop=stop_list,
            stream=stream,
        )

    def invoke_llm_with_structured_output(
        self,
        *,
        prompt_messages: Sequence[PromptMessage],
        json_schema: Mapping[str, Any],
        model_parameters: Mapping[str, Any],
        stop: Sequence[str] | None,
        stream: bool,
    ) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]:
        """Delegate to the module-level ``invoke_llm_with_structured_output`` helper."""
        # Inside a method the bare name resolves to the module-level import,
        # not to this method, so the call below is not recursive.
        provider_name = self.provider
        schema = self.get_model_schema()
        stop_list = list(stop) if stop else []
        return invoke_llm_with_structured_output(
            provider=provider_name,
            model_schema=schema,
            model_instance=self._model_instance,
            prompt_messages=prompt_messages,
            json_schema=json_schema,
            model_parameters=model_parameters,
            stop=stop_list,
            stream=stream,
        )

    def is_structured_output_parse_error(self, error: Exception) -> bool:
        return isinstance(error, OutputParserError)


class DifyPromptMessageSerializer(PromptMessageSerializerProtocol):
    """Serializes prompt messages via ``PromptMessageUtil.prompt_messages_to_prompt_for_saving``."""

    def serialize(
        self,
        *,
        model_mode: LLMMode,
        prompt_messages: Sequence[PromptMessage],
    ) -> Any:
        serialized = PromptMessageUtil.prompt_messages_to_prompt_for_saving(
            model_mode=model_mode,
            prompt_messages=prompt_messages,
        )
        return serialized
+ upload_file.extension, + "mime_type": upload_file.mime_type, + "type": FileType.IMAGE, + "transfer_method": FileTransferMethod.LOCAL_FILE, + "remote_url": upload_file.source_url, + "reference": build_file_reference(record_id=str(upload_file.id)), + "size": upload_file.size, + } + ) + for _, upload_file in attachments_with_bindings + ] + + +class DifyToolFileManager(ToolFileManagerProtocol): + """Workflow adapter that resolves conversation scope outside `graphon`.""" + + _conversation_id_getter: Callable[[], str | None] | None + + def __init__( + self, + run_context: Mapping[str, Any] | DifyRunContext, + *, + conversation_id_getter: Callable[[], str | None] | None = None, + ) -> None: + self._run_context = resolve_dify_run_context(run_context) + self._manager = ToolFileManager() + self._conversation_id_getter = conversation_id_getter + + def create_file_by_raw( + self, + *, + file_binary: bytes, + mimetype: str, + filename: str | None = None, + ) -> Any: + conversation_id = self._conversation_id_getter() if self._conversation_id_getter is not None else None + return self._manager.create_file_by_raw( + user_id=self._run_context.user_id, + tenant_id=self._run_context.tenant_id, + conversation_id=conversation_id, + file_binary=file_binary, + mimetype=mimetype, + filename=filename, + ) + + def get_file_generator_by_tool_file_id(self, tool_file_id: str): + return self._manager.get_file_generator_by_tool_file_id(tool_file_id) + + +@dataclass(frozen=True, slots=True) +class _WorkflowToolRuntimeSpec: + provider_type: CoreToolProviderType + provider_id: str + tool_name: str + tool_configurations: dict[str, Any] + credential_id: str | None = None + + +@dataclass(frozen=True, slots=True) +class _WorkflowToolRuntimeBinding: + """Workflow-private runtime state stored inside the opaque graph handle. + + The binding keeps conversation scope in `core.workflow` while `graphon` + continues to treat the handle as an opaque token. 
+ """ + + tool: Tool + conversation_id: str | None = None + + +class DifyToolNodeRuntime(ToolNodeRuntimeProtocol): + def __init__(self, run_context: Mapping[str, Any] | DifyRunContext) -> None: + self._run_context = resolve_dify_run_context(run_context) + self._file_reference_factory = DifyFileReferenceFactory(self._run_context) + + @property + def file_reference_factory(self) -> FileReferenceFactoryProtocol: + return self._file_reference_factory + + def build_file_reference(self, *, mapping: Mapping[str, Any]): + return self._file_reference_factory.build_from_mapping(mapping=mapping) + + def get_runtime( + self, + *, + node_id: str, + node_data: ToolNodeData, + variable_pool, + ) -> ToolRuntimeHandle: + try: + tool_runtime = ToolManager.get_workflow_tool_runtime( + self._run_context.tenant_id, + self._run_context.app_id, + node_id, + self._build_tool_runtime_spec(node_data), + self._run_context.user_id, + self._run_context.invoke_from, + variable_pool, + ) + except ToolNodeError: + raise + except Exception as exc: + raise ToolRuntimeResolutionError(str(exc)) from exc + + conversation_id = ( + None if variable_pool is None else get_system_text(variable_pool, SystemVariableKey.CONVERSATION_ID) + ) + return ToolRuntimeHandle(raw=_WorkflowToolRuntimeBinding(tool=tool_runtime, conversation_id=conversation_id)) + + def get_runtime_parameters( + self, + *, + tool_runtime: ToolRuntimeHandle, + ) -> Sequence[ToolRuntimeParameter]: + tool = self._tool_from_handle(tool_runtime) + return [ + ToolRuntimeParameter(name=parameter.name, required=parameter.required) + for parameter in (tool.get_merged_runtime_parameters() or []) + ] + + def invoke( + self, + *, + tool_runtime: ToolRuntimeHandle, + tool_parameters: Mapping[str, Any], + workflow_call_depth: int, + provider_name: str, + ) -> Generator[ToolRuntimeMessage, None, None]: + runtime_binding = self._binding_from_handle(tool_runtime) + tool = runtime_binding.tool + callback = DifyWorkflowCallbackHandler() + + try: + messages 
= ToolEngine.generic_invoke( + tool=tool, + tool_parameters=dict(tool_parameters), + user_id=self._run_context.user_id, + workflow_tool_callback=callback, + workflow_call_depth=workflow_call_depth, + app_id=self._run_context.app_id, + conversation_id=runtime_binding.conversation_id, + ) + except Exception as exc: + raise self._map_invocation_exception(exc, provider_name=provider_name) from exc + + transformed_messages = ToolFileMessageTransformer.transform_tool_invoke_messages( + messages=messages, + user_id=self._run_context.user_id, + tenant_id=self._run_context.tenant_id, + conversation_id=runtime_binding.conversation_id, + ) + + return self._adapt_messages(transformed_messages, provider_name=provider_name) + + def get_usage( + self, + *, + tool_runtime: ToolRuntimeHandle, + ) -> LLMUsage: + latest = getattr(self._binding_from_handle(tool_runtime).tool, "latest_usage", None) + if isinstance(latest, LLMUsage): + return latest + if isinstance(latest, dict): + return LLMUsage.model_validate(latest) + return LLMUsage.empty_usage() + + def resolve_provider_icons( + self, + *, + provider_name: str, + default_icon: str | None = None, + ) -> tuple[str | Mapping[str, str] | None, str | Mapping[str, str] | None]: + icon: str | Mapping[str, str] | None = default_icon + icon_dark: str | Mapping[str, str] | None = None + + manager = PluginInstaller() + plugins = manager.list_plugins(self._run_context.tenant_id) + try: + current_plugin = next(plugin for plugin in plugins if f"{plugin.plugin_id}/{plugin.name}" == provider_name) + icon = current_plugin.declaration.icon + except StopIteration: + pass + + try: + builtin_tool = next( + provider + for provider in BuiltinToolManageService.list_builtin_tools( + self._run_context.user_id, + self._run_context.tenant_id, + ) + if provider.name == provider_name + ) + icon = builtin_tool.icon + icon_dark = builtin_tool.icon_dark + except StopIteration: + pass + + return icon, icon_dark + + @staticmethod + def 
_tool_from_handle(tool_runtime: ToolRuntimeHandle) -> Tool: + return DifyToolNodeRuntime._binding_from_handle(tool_runtime).tool + + @staticmethod + def _binding_from_handle(tool_runtime: ToolRuntimeHandle) -> _WorkflowToolRuntimeBinding: + if isinstance(tool_runtime.raw, _WorkflowToolRuntimeBinding): + return tool_runtime.raw + return _WorkflowToolRuntimeBinding(tool=cast("Tool", tool_runtime.raw)) + + @staticmethod + def _build_tool_runtime_spec(node_data: ToolNodeData) -> _WorkflowToolRuntimeSpec: + return _WorkflowToolRuntimeSpec( + provider_type=CoreToolProviderType(node_data.provider_type.value), + provider_id=node_data.provider_id, + tool_name=node_data.tool_name, + tool_configurations=dict(node_data.tool_configurations), + credential_id=node_data.credential_id, + ) + + def _adapt_messages( + self, + messages: Generator[CoreToolInvokeMessage, None, None], + *, + provider_name: str, + ) -> Generator[ToolRuntimeMessage, None, None]: + try: + for message in messages: + yield self._convert_message(message) + except Exception as exc: + raise self._map_invocation_exception(exc, provider_name=provider_name) from exc + + def _convert_message(self, message: CoreToolInvokeMessage) -> ToolRuntimeMessage: + graph_message_type = ToolRuntimeMessage.MessageType(message.type.value) + graph_message = self._convert_message_payload(message.message) + graph_meta = message.meta.copy() if message.meta is not None else None + return ToolRuntimeMessage(type=graph_message_type, message=graph_message, meta=graph_meta) + + def _convert_message_payload( + self, + message: CoreToolInvokeMessage.TextMessage + | CoreToolInvokeMessage.JsonMessage + | CoreToolInvokeMessage.BlobChunkMessage + | CoreToolInvokeMessage.BlobMessage + | CoreToolInvokeMessage.LogMessage + | CoreToolInvokeMessage.FileMessage + | CoreToolInvokeMessage.VariableMessage + | CoreToolInvokeMessage.RetrieverResourceMessage + | None, + ) -> ( + ToolRuntimeMessage.TextMessage + | ToolRuntimeMessage.JsonMessage + | 
ToolRuntimeMessage.BlobChunkMessage + | ToolRuntimeMessage.BlobMessage + | ToolRuntimeMessage.LogMessage + | ToolRuntimeMessage.FileMessage + | ToolRuntimeMessage.VariableMessage + | ToolRuntimeMessage.RetrieverResourceMessage + | None + ): + if message is None: + return None + + from core.tools.entities.tool_entities import ToolInvokeMessage as CoreToolInvokeMessage + + if isinstance(message, CoreToolInvokeMessage.TextMessage): + return ToolRuntimeMessage.TextMessage(text=message.text) + if isinstance(message, CoreToolInvokeMessage.JsonMessage): + return ToolRuntimeMessage.JsonMessage( + json_object=message.json_object, + suppress_output=message.suppress_output, + ) + if isinstance(message, CoreToolInvokeMessage.BlobMessage): + return ToolRuntimeMessage.BlobMessage(blob=message.blob) + if isinstance(message, CoreToolInvokeMessage.BlobChunkMessage): + return ToolRuntimeMessage.BlobChunkMessage( + id=message.id, + sequence=message.sequence, + total_length=message.total_length, + blob=message.blob, + end=message.end, + ) + if isinstance(message, CoreToolInvokeMessage.FileMessage): + return ToolRuntimeMessage.FileMessage(file_marker=message.file_marker) + if isinstance(message, CoreToolInvokeMessage.VariableMessage): + return ToolRuntimeMessage.VariableMessage( + variable_name=message.variable_name, + variable_value=message.variable_value, + stream=message.stream, + ) + if isinstance(message, CoreToolInvokeMessage.LogMessage): + return ToolRuntimeMessage.LogMessage( + id=message.id, + label=message.label, + parent_id=message.parent_id, + error=message.error, + status=ToolRuntimeMessage.LogMessage.LogStatus(message.status.value), + data=dict(message.data), + metadata=dict(message.metadata), + ) + if isinstance(message, CoreToolInvokeMessage.RetrieverResourceMessage): + retriever_resources = [ + resource.model_dump() if hasattr(resource, "model_dump") else dict(resource) + for resource in message.retriever_resources + ] + return 
ToolRuntimeMessage.RetrieverResourceMessage( + retriever_resources=retriever_resources, + context=message.context, + ) + + raise TypeError(f"unsupported tool message payload: {type(message).__name__}") + + @staticmethod + def _map_invocation_exception(exc: Exception, *, provider_name: str) -> ToolNodeError: + if isinstance(exc, ToolNodeError): + return exc + if isinstance(exc, PluginInvokeError): + return ToolRuntimeInvocationError(exc.to_user_friendly_error(plugin_name=provider_name)) + if isinstance(exc, PluginDaemonClientSideError): + return ToolRuntimeInvocationError(f"Failed to invoke tool, error: {exc.description}") + if isinstance(exc, ToolInvokeError): + return ToolRuntimeInvocationError(f"Failed to invoke tool {provider_name}: {exc}") + return ToolRuntimeInvocationError(str(exc)) + + +class DifyHumanInputNodeRuntime(HumanInputNodeRuntimeProtocol): + def __init__( + self, + run_context: Mapping[str, Any] | DifyRunContext, + *, + workflow_execution_id_getter: Callable[[], str | None] | None = None, + form_repository: HumanInputFormRepository | None = None, + ) -> None: + self._run_context = resolve_dify_run_context(run_context) + self._workflow_execution_id_getter = workflow_execution_id_getter + self._form_repository = form_repository + + def _invoke_source(self) -> str: + invoke_from = self._run_context.invoke_from + if isinstance(invoke_from, str): + return invoke_from + return str(getattr(invoke_from, "value", invoke_from)) + + def _resolve_delivery_methods(self, *, node_data: HumanInputNodeData) -> Sequence[DeliveryChannelConfig]: + invoke_source = self._invoke_source() + methods = [method for method in parse_human_input_delivery_methods(node_data) if method.enabled] + if invoke_source in {"debugger", "explore"}: + methods = [method for method in methods if method.type != DeliveryMethodType.WEBAPP] + return [ + apply_dify_debug_email_recipient( + method, + enabled=invoke_source == "debugger", + actor_id=self._run_context.user_id, + ) + for method in 
methods + ] + + def _display_in_ui(self, *, node_data: HumanInputNodeData) -> bool: + if self._invoke_source() == "debugger": + return True + return is_human_input_webapp_enabled(node_data) + + def build_form_repository(self) -> HumanInputFormRepository: + if self._form_repository is not None: + return self._form_repository + + return self._build_form_repository() + + def _build_form_repository(self) -> HumanInputFormRepository: + invoke_source = self._invoke_source() + return HumanInputFormRepositoryImpl( + tenant_id=self._run_context.tenant_id, + app_id=self._run_context.app_id, + workflow_execution_id=self._workflow_execution_id_getter() if self._workflow_execution_id_getter else None, + invoke_source=invoke_source, + submission_actor_id=self._run_context.user_id if invoke_source in {"debugger", "explore"} else None, + ) + + def with_form_repository(self, form_repository: HumanInputFormRepository) -> DifyHumanInputNodeRuntime: + return DifyHumanInputNodeRuntime( + self._run_context, + workflow_execution_id_getter=self._workflow_execution_id_getter, + form_repository=form_repository, + ) + + def get_form(self, *, node_id: str) -> HumanInputFormStateProtocol | None: + repo = self.build_form_repository() + return repo.get_form(node_id) + + def create_form( + self, + *, + node_id: str, + node_data: HumanInputNodeData, + rendered_content: str, + resolved_default_values: Mapping[str, Any], + ) -> HumanInputFormStateProtocol: + repo = self.build_form_repository() + params = FormCreateParams( + workflow_execution_id=self._workflow_execution_id_getter() if self._workflow_execution_id_getter else None, + node_id=node_id, + form_config=node_data, + rendered_content=rendered_content, + delivery_methods=self._resolve_delivery_methods(node_data=node_data), + display_in_ui=self._display_in_ui(node_data=node_data), + resolved_default_values=resolved_default_values, + ) + return repo.create_form(params) + + +def build_dify_llm_file_saver( + *, + run_context: Mapping[str, Any] | 
DifyRunContext, + http_client: HttpClientProtocol, + conversation_id_getter: Callable[[], str | None] | None = None, +) -> LLMFileSaver: + from graphon.nodes.llm.file_saver import FileSaverImpl + + return FileSaverImpl( + tool_file_manager=DifyToolFileManager(run_context, conversation_id_getter=conversation_id_getter), + file_reference_factory=DifyFileReferenceFactory(run_context), + http_client=http_client, + ) diff --git a/api/core/workflow/nodes/agent/agent_node.py b/api/core/workflow/nodes/agent/agent_node.py index 5699ccf404..bfd5536e4a 100644 --- a/api/core/workflow/nodes/agent/agent_node.py +++ b/api/core/workflow/nodes/agent/agent_node.py @@ -3,11 +3,14 @@ from __future__ import annotations from collections.abc import Generator, Mapping, Sequence from typing import TYPE_CHECKING, Any -from dify_graph.entities.graph_config import NodeConfigDict -from dify_graph.enums import BuiltinNodeTypes, SystemVariableKey, WorkflowNodeExecutionStatus -from dify_graph.node_events import NodeEventBase, NodeRunResult, StreamCompletedEvent -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser +from graphon.entities.graph_config import NodeConfigDict +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus +from graphon.node_events import NodeEventBase, NodeRunResult, StreamCompletedEvent +from graphon.nodes.base.node import Node +from graphon.nodes.base.variable_template_parser import VariableTemplateParser + +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext +from core.workflow.system_variables import SystemVariableKey, get_system_text from .entities import AgentNodeData from .exceptions import ( @@ -19,8 +22,8 @@ from .runtime_support import AgentRuntimeSupport from .strategy_protocols import AgentStrategyPresentationProvider, AgentStrategyResolver if TYPE_CHECKING: - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import 
GraphRuntimeState + from graphon.entities import GraphInitParams + from graphon.runtime import GraphRuntimeState class AgentNode(Node[AgentNodeData]): @@ -59,7 +62,7 @@ class AgentNode(Node[AgentNodeData]): return "1" def populate_start_event(self, event) -> None: - dify_ctx = self.require_dify_context() + dify_ctx = DifyRunContext.model_validate(self.require_run_context_value(DIFY_RUN_CONTEXT_KEY)) event.extras["agent_strategy"] = { "name": self.node_data.agent_strategy_name, "icon": self._presentation_provider.get_icon( @@ -71,7 +74,7 @@ class AgentNode(Node[AgentNodeData]): def _run(self) -> Generator[NodeEventBase, None, None]: from core.plugin.impl.exc import PluginDaemonClientSideError - dify_ctx = self.require_dify_context() + dify_ctx = DifyRunContext.model_validate(self.require_run_context_value(DIFY_RUN_CONTEXT_KEY)) try: strategy = self._strategy_resolver.resolve( @@ -97,6 +100,7 @@ class AgentNode(Node[AgentNodeData]): node_data=self.node_data, strategy=strategy, tenant_id=dify_ctx.tenant_id, + user_id=dify_ctx.user_id, app_id=dify_ctx.app_id, invoke_from=dify_ctx.invoke_from, ) @@ -106,20 +110,21 @@ class AgentNode(Node[AgentNodeData]): node_data=self.node_data, strategy=strategy, tenant_id=dify_ctx.tenant_id, + user_id=dify_ctx.user_id, app_id=dify_ctx.app_id, invoke_from=dify_ctx.invoke_from, for_log=True, ) credentials = self._runtime_support.build_credentials(parameters=parameters) - conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID]) + conversation_id = get_system_text(self.graph_runtime_state.variable_pool, SystemVariableKey.CONVERSATION_ID) try: message_stream = strategy.invoke( params=parameters, user_id=dify_ctx.user_id, app_id=dify_ctx.app_id, - conversation_id=conversation_id.text if conversation_id else None, + conversation_id=conversation_id, credentials=credentials, ) except Exception as e: @@ -146,6 +151,7 @@ class AgentNode(Node[AgentNodeData]): parameters_for_log=parameters_for_log, 
user_id=dify_ctx.user_id, tenant_id=dify_ctx.tenant_id, + conversation_id=conversation_id, node_type=self.node_type, node_id=self._node_id, node_execution_id=self.id, diff --git a/api/core/workflow/nodes/agent/entities.py b/api/core/workflow/nodes/agent/entities.py index 91fed39795..c52aad150b 100644 --- a/api/core/workflow/nodes/agent/entities.py +++ b/api/core/workflow/nodes/agent/entities.py @@ -1,12 +1,12 @@ from enum import IntEnum, StrEnum, auto from typing import Any, Literal, Union +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import BuiltinNodeTypes, NodeType from pydantic import BaseModel from core.prompt.entities.advanced_prompt_entities import MemoryConfig from core.tools.entities.tool_entities import ToolSelector -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, NodeType class AgentNodeData(BaseNodeData): diff --git a/api/core/workflow/nodes/agent/message_transformer.py b/api/core/workflow/nodes/agent/message_transformer.py index f58a5665f4..db74590ed7 100644 --- a/api/core/workflow/nodes/agent/message_transformer.py +++ b/api/core/workflow/nodes/agent/message_transformer.py @@ -3,23 +3,24 @@ from __future__ import annotations from collections.abc import Generator, Mapping from typing import Any, cast -from sqlalchemy import select -from sqlalchemy.orm import Session - -from core.tools.entities.tool_entities import ToolInvokeMessage -from core.tools.utils.message_transformer import ToolFileMessageTransformer -from dify_graph.enums import BuiltinNodeTypes, NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus -from dify_graph.file import File, FileTransferMethod -from dify_graph.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata -from dify_graph.model_runtime.utils.encoders import jsonable_encoder -from dify_graph.node_events import ( +from graphon.enums import BuiltinNodeTypes, NodeType, WorkflowNodeExecutionMetadataKey, 
WorkflowNodeExecutionStatus +from graphon.file import File, FileTransferMethod, get_file_type_by_mime_type +from graphon.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata +from graphon.model_runtime.utils.encoders import jsonable_encoder +from graphon.node_events import ( AgentLogEvent, NodeEventBase, NodeRunResult, StreamChunkEvent, StreamCompletedEvent, ) -from dify_graph.variables.segments import ArrayFileSegment +from graphon.variables.segments import ArrayFileSegment +from sqlalchemy import select +from sqlalchemy.orm import Session + +from core.app.file_access import DatabaseFileAccessController +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.utils.message_transformer import ToolFileMessageTransformer from extensions.ext_database import db from factories import file_factory from models import ToolFile @@ -27,6 +28,8 @@ from services.tools.builtin_tools_manage_service import BuiltinToolManageService from .exceptions import AgentNodeError, AgentVariableTypeError, ToolFileNotFoundError +_file_access_controller = DatabaseFileAccessController() + class AgentMessageTransformer: def transform( @@ -37,6 +40,7 @@ class AgentMessageTransformer: parameters_for_log: dict[str, Any], user_id: str, tenant_id: str, + conversation_id: str | None, node_type: NodeType, node_id: str, node_execution_id: str, @@ -47,7 +51,7 @@ class AgentMessageTransformer: messages=messages, user_id=user_id, tenant_id=tenant_id, - conversation_id=None, + conversation_id=conversation_id, ) text = "" @@ -70,10 +74,12 @@ class AgentMessageTransformer: url = message.message.text if message.meta: transfer_method = message.meta.get("transfer_method", FileTransferMethod.TOOL_FILE) + tool_file_id = message.meta.get("tool_file_id") else: transfer_method = FileTransferMethod.TOOL_FILE - - tool_file_id = str(url).split("/")[-1].split(".")[0] + tool_file_id = None + if not isinstance(tool_file_id, str) or not tool_file_id: + raise 
ToolFileNotFoundError("missing tool_file_id metadata") with Session(db.engine) as session: stmt = select(ToolFile).where(ToolFile.id == tool_file_id) @@ -83,20 +89,23 @@ class AgentMessageTransformer: mapping = { "tool_file_id": tool_file_id, - "type": file_factory.get_file_type_by_mime_type(tool_file.mimetype), + "type": get_file_type_by_mime_type(tool_file.mimetype), "transfer_method": transfer_method, "url": url, } file = file_factory.build_from_mapping( mapping=mapping, tenant_id=tenant_id, + access_controller=_file_access_controller, ) files.append(file) elif message.type == ToolInvokeMessage.MessageType.BLOB: assert isinstance(message.message, ToolInvokeMessage.TextMessage) assert message.meta - tool_file_id = message.message.text.split("/")[-1].split(".")[0] + tool_file_id = message.meta.get("tool_file_id") + if not isinstance(tool_file_id, str) or not tool_file_id: + raise ToolFileNotFoundError("missing tool_file_id metadata") with Session(db.engine) as session: stmt = select(ToolFile).where(ToolFile.id == tool_file_id) tool_file = session.scalar(stmt) @@ -111,6 +120,7 @@ class AgentMessageTransformer: file_factory.build_from_mapping( mapping=mapping, tenant_id=tenant_id, + access_controller=_file_access_controller, ) ) elif message.type == ToolInvokeMessage.MessageType.TEXT: diff --git a/api/core/workflow/nodes/agent/runtime_support.py b/api/core/workflow/nodes/agent/runtime_support.py index 2ff7c964b9..be50edbc4d 100644 --- a/api/core/workflow/nodes/agent/runtime_support.py +++ b/api/core/workflow/nodes/agent/runtime_support.py @@ -4,6 +4,8 @@ import json from collections.abc import Sequence from typing import Any, cast +from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelType +from graphon.runtime import VariablePool from packaging.version import Version from pydantic import ValidationError from sqlalchemy import select @@ -12,15 +14,12 @@ from sqlalchemy.orm import Session from core.agent.entities import AgentToolEntity from 
core.agent.plugin_entities import AgentStrategyParameter from core.memory.token_buffer_memory import TokenBufferMemory -from core.model_manager import ModelInstance, ModelManager +from core.model_manager import ModelInstance from core.plugin.entities.request import InvokeCredentials -from core.provider_manager import ProviderManager +from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly from core.tools.entities.tool_entities import ToolIdentity, ToolParameter, ToolProviderType from core.tools.tool_manager import ToolManager -from dify_graph.enums import SystemVariableKey -from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelType -from dify_graph.runtime import VariablePool -from dify_graph.variables.segments import StringSegment +from core.workflow.system_variables import SystemVariableKey, get_system_text from extensions.ext_database import db from models.model import Conversation @@ -38,6 +37,7 @@ class AgentRuntimeSupport: node_data: AgentNodeData, strategy: ResolvedAgentStrategy, tenant_id: str, + user_id: str, app_id: str, invoke_from: Any, for_log: bool = False, @@ -141,6 +141,7 @@ class AgentRuntimeSupport: tenant_id, app_id, entity, + user_id, invoke_from, runtime_variable_pool, ) @@ -174,7 +175,11 @@ class AgentRuntimeSupport: value = tool_value if parameter.type == AgentStrategyParameter.AgentStrategyParameterType.MODEL_SELECTOR: value = cast(dict[str, Any], value) - model_instance, model_schema = self.fetch_model(tenant_id=tenant_id, value=value) + model_instance, model_schema = self.fetch_model( + tenant_id=tenant_id, + user_id=user_id, + value=value, + ) history_prompt_messages = [] if node_data.memory: memory = self.fetch_memory( @@ -219,10 +224,9 @@ class AgentRuntimeSupport: app_id: str, model_instance: ModelInstance, ) -> TokenBufferMemory | None: - conversation_id_variable = variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID]) - if not isinstance(conversation_id_variable, 
StringSegment): + conversation_id = get_system_text(variable_pool, SystemVariableKey.CONVERSATION_ID) + if conversation_id is None: return None - conversation_id = conversation_id_variable.value with Session(db.engine, expire_on_commit=False) as session: stmt = select(Conversation).where(Conversation.app_id == app_id, Conversation.id == conversation_id) @@ -232,9 +236,15 @@ class AgentRuntimeSupport: return TokenBufferMemory(conversation=conversation, model_instance=model_instance) - def fetch_model(self, *, tenant_id: str, value: dict[str, Any]) -> tuple[ModelInstance, AIModelEntity | None]: - provider_manager = ProviderManager() - provider_model_bundle = provider_manager.get_provider_model_bundle( + def fetch_model( + self, + *, + tenant_id: str, + user_id: str, + value: dict[str, Any], + ) -> tuple[ModelInstance, AIModelEntity | None]: + assembly = create_plugin_model_assembly(tenant_id=tenant_id, user_id=user_id) + provider_model_bundle = assembly.provider_manager.get_provider_model_bundle( tenant_id=tenant_id, provider=value.get("provider", ""), model_type=ModelType.LLM, @@ -246,7 +256,7 @@ class AgentRuntimeSupport: ) provider_name = provider_model_bundle.configuration.provider.provider model_type_instance = provider_model_bundle.model_type_instance - model_instance = ModelManager().get_model_instance( + model_instance = assembly.model_manager.get_model_instance( tenant_id=tenant_id, provider=provider_name, model_type=ModelType(value.get("model_type", "")), diff --git a/api/core/workflow/nodes/datasource/datasource_node.py b/api/core/workflow/nodes/datasource/datasource_node.py index 44f4a23a5a..d9247b2593 100644 --- a/api/core/workflow/nodes/datasource/datasource_node.py +++ b/api/core/workflow/nodes/datasource/datasource_node.py @@ -1,22 +1,30 @@ from collections.abc import Generator, Mapping, Sequence from typing import TYPE_CHECKING, Any +from graphon.entities.graph_config import NodeConfigDict +from graphon.enums import ( + BuiltinNodeTypes, + 
NodeExecutionType, + WorkflowNodeExecutionMetadataKey, + WorkflowNodeExecutionStatus, +) +from graphon.node_events import NodeRunResult, StreamCompletedEvent +from graphon.nodes.base.node import Node +from graphon.nodes.base.variable_template_parser import VariableTemplateParser + +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext from core.datasource.datasource_manager import DatasourceManager from core.datasource.entities.datasource_entities import DatasourceProviderType from core.plugin.impl.exc import PluginDaemonClientSideError -from dify_graph.entities.graph_config import NodeConfigDict -from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from dify_graph.enums import BuiltinNodeTypes, NodeExecutionType, SystemVariableKey, WorkflowNodeExecutionMetadataKey -from dify_graph.node_events import NodeRunResult, StreamCompletedEvent -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser +from core.workflow.file_reference import resolve_file_record_id +from core.workflow.system_variables import SystemVariableKey, get_system_segment from .entities import DatasourceNodeData, DatasourceParameter, OnlineDriveDownloadFileParam from .exc import DatasourceNodeError if TYPE_CHECKING: - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState + from graphon.entities import GraphInitParams + from graphon.runtime import GraphRuntimeState class DatasourceNode(Node[DatasourceNodeData]): @@ -50,15 +58,14 @@ class DatasourceNode(Node[DatasourceNodeData]): """ Run the datasource node """ - - dify_ctx = self.require_dify_context() + dify_ctx = DifyRunContext.model_validate(self.require_run_context_value(DIFY_RUN_CONTEXT_KEY)) node_data = self.node_data variable_pool = self.graph_runtime_state.variable_pool - datasource_type_segment = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_TYPE]) + 
datasource_type_segment = get_system_segment(variable_pool, SystemVariableKey.DATASOURCE_TYPE) if not datasource_type_segment: raise DatasourceNodeError("Datasource type is not set") datasource_type = str(datasource_type_segment.value) if datasource_type_segment.value else None - datasource_info_segment = variable_pool.get(["sys", SystemVariableKey.DATASOURCE_INFO]) + datasource_info_segment = get_system_segment(variable_pool, SystemVariableKey.DATASOURCE_INFO) if not datasource_info_segment: raise DatasourceNodeError("Datasource info is not set") datasource_info_value = datasource_info_segment.value @@ -131,12 +138,14 @@ class DatasourceNode(Node[DatasourceNodeData]): ) ) case DatasourceProviderType.LOCAL_FILE: - related_id = datasource_info.get("related_id") - if not related_id: + file_id = resolve_file_record_id( + datasource_info.get("reference") or datasource_info.get("related_id") + ) + if not file_id: raise DatasourceNodeError("File is not exist") file_info = self.datasource_manager.get_upload_file_by_id( - file_id=related_id, tenant_id=dify_ctx.tenant_id + file_id=file_id, tenant_id=dify_ctx.tenant_id ) variable_pool.add([self._node_id, "file"], file_info) # variable_pool.add([self.node_id, "file"], file_info.to_dict()) diff --git a/api/core/workflow/nodes/datasource/entities.py b/api/core/workflow/nodes/datasource/entities.py index 65864474b0..cad32f8d5b 100644 --- a/api/core/workflow/nodes/datasource/entities.py +++ b/api/core/workflow/nodes/datasource/entities.py @@ -1,11 +1,10 @@ from typing import Any, Literal, Union +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import BuiltinNodeTypes, NodeType from pydantic import BaseModel, field_validator from pydantic_core.core_schema import ValidationInfo -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, NodeType - class DatasourceEntity(BaseModel): plugin_id: str diff --git a/api/core/workflow/nodes/datasource/protocols.py 
b/api/core/workflow/nodes/datasource/protocols.py index c006e0885c..776e267317 100644 --- a/api/core/workflow/nodes/datasource/protocols.py +++ b/api/core/workflow/nodes/datasource/protocols.py @@ -1,8 +1,8 @@ from collections.abc import Generator from typing import Any, Protocol -from dify_graph.file import File -from dify_graph.node_events import StreamChunkEvent, StreamCompletedEvent +from graphon.file import File +from graphon.node_events import StreamChunkEvent, StreamCompletedEvent from .entities import DatasourceParameter, OnlineDriveDownloadFileParam diff --git a/api/core/workflow/nodes/knowledge_index/entities.py b/api/core/workflow/nodes/knowledge_index/entities.py index 8d2e9bf3cb..cba6c12dca 100644 --- a/api/core/workflow/nodes/knowledge_index/entities.py +++ b/api/core/workflow/nodes/knowledge_index/entities.py @@ -1,12 +1,12 @@ from typing import Literal, Union +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import NodeType from pydantic import BaseModel from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import NodeType class RerankingModelConfig(BaseModel): diff --git a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py index 4ea9091c5b..bb72fe3881 100644 --- a/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py +++ b/api/core/workflow/nodes/knowledge_index/knowledge_index_node.py @@ -2,16 +2,17 @@ import logging from collections.abc import Mapping from typing import TYPE_CHECKING, Any +from graphon.entities.graph_config import NodeConfigDict +from graphon.enums import NodeExecutionType, WorkflowNodeExecutionStatus +from graphon.node_events import NodeRunResult +from 
graphon.nodes.base.node import Node +from graphon.nodes.base.template import Template + from core.rag.index_processor.index_processor import IndexProcessor from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict from core.rag.summary_index.summary_index import SummaryIndex from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE -from dify_graph.entities.graph_config import NodeConfigDict -from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from dify_graph.enums import NodeExecutionType, SystemVariableKey -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.base.template import Template +from core.workflow.system_variables import SystemVariableKey, get_system_segment, get_system_text from .entities import KnowledgeIndexNodeData from .exc import ( @@ -19,8 +20,8 @@ from .exc import ( ) if TYPE_CHECKING: - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState + from graphon.entities import GraphInitParams + from graphon.runtime import GraphRuntimeState logger = logging.getLogger(__name__) _INVOKE_FROM_DEBUGGER = "debugger" @@ -46,21 +47,20 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]): variable_pool = self.graph_runtime_state.variable_pool # get dataset id as string - dataset_id_segment = variable_pool.get(["sys", SystemVariableKey.DATASET_ID]) + dataset_id_segment = get_system_segment(variable_pool, SystemVariableKey.DATASET_ID) if not dataset_id_segment: raise KnowledgeIndexNodeError("Dataset ID is required.") dataset_id: str = dataset_id_segment.value # get document id as string (may be empty when not provided) - document_id_segment = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID]) + document_id_segment = get_system_segment(variable_pool, SystemVariableKey.DOCUMENT_ID) document_id: str = document_id_segment.value if document_id_segment else "" # extract 
variables variable = variable_pool.get(node_data.index_chunk_variable_selector) if not variable: raise KnowledgeIndexNodeError("Index chunk variable is required.") - invoke_from = variable_pool.get(["sys", SystemVariableKey.INVOKE_FROM]) - invoke_from_value = str(invoke_from.value) if invoke_from else None + invoke_from_value = get_system_text(variable_pool, SystemVariableKey.INVOKE_FROM) is_preview = invoke_from_value == _INVOKE_FROM_DEBUGGER chunks = variable.value @@ -87,8 +87,8 @@ class KnowledgeIndexNode(Node[KnowledgeIndexNodeData]): outputs=outputs.model_dump(exclude_none=True), ) - original_document_id_segment = variable_pool.get(["sys", SystemVariableKey.ORIGINAL_DOCUMENT_ID]) - batch = variable_pool.get(["sys", SystemVariableKey.BATCH]) + original_document_id_segment = get_system_segment(variable_pool, SystemVariableKey.ORIGINAL_DOCUMENT_ID) + batch = get_system_segment(variable_pool, SystemVariableKey.BATCH) if not batch: raise KnowledgeIndexNodeError("Batch is required.") diff --git a/api/core/workflow/nodes/knowledge_retrieval/entities.py b/api/core/workflow/nodes/knowledge_retrieval/entities.py index bc5618685a..b1fa8593ef 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/entities.py +++ b/api/core/workflow/nodes/knowledge_retrieval/entities.py @@ -1,12 +1,11 @@ from collections.abc import Sequence from typing import Literal +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import BuiltinNodeTypes, NodeType +from graphon.nodes.llm.entities import ModelConfig, VisionConfig from pydantic import BaseModel, Field -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, NodeType -from dify_graph.nodes.llm.entities import ModelConfig, VisionConfig - class RerankingModelConfig(BaseModel): """ diff --git a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py index 
80f59140be..13624b27b3 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py +++ b/api/core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py @@ -8,27 +8,30 @@ import logging from collections.abc import Mapping, Sequence from typing import TYPE_CHECKING, Any, Literal -from core.app.app_config.entities import DatasetRetrieveConfigEntity -from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict -from core.rag.retrieval.dataset_retrieval import DatasetRetrieval -from dify_graph.entities import GraphInitParams -from dify_graph.entities.graph_config import NodeConfigDict -from dify_graph.enums import ( +from graphon.entities import GraphInitParams +from graphon.entities.graph_config import NodeConfigDict +from graphon.enums import ( BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus, ) -from dify_graph.model_runtime.entities.llm_entities import LLMUsage -from dify_graph.model_runtime.utils.encoders import jsonable_encoder -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.base import LLMUsageTrackingMixin -from dify_graph.nodes.base.node import Node -from dify_graph.variables import ( +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.model_runtime.utils.encoders import jsonable_encoder +from graphon.node_events import NodeRunResult +from graphon.nodes.base import LLMUsageTrackingMixin +from graphon.nodes.base.node import Node +from graphon.variables import ( ArrayFileSegment, FileSegment, StringSegment, ) -from dify_graph.variables.segments import ArrayObjectSegment +from graphon.variables.segments import ArrayObjectSegment + +from core.app.app_config.entities import DatasetRetrieveConfigEntity +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext +from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict +from 
core.rag.retrieval.dataset_retrieval import DatasetRetrieval +from core.workflow.file_reference import parse_file_reference from .entities import ( Condition, @@ -42,8 +45,8 @@ from .exc import ( from .retrieval import KnowledgeRetrievalRequest, Source if TYPE_CHECKING: - from dify_graph.file.models import File - from dify_graph.runtime import GraphRuntimeState + from graphon.file import File + from graphon.runtime import GraphRuntimeState logger = logging.getLogger(__name__) @@ -160,7 +163,7 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD def _fetch_dataset_retriever( self, node_data: KnowledgeRetrievalNodeData, variables: dict[str, Any] ) -> tuple[list[Source], LLMUsage]: - dify_ctx = self.require_dify_context() + dify_ctx = DifyRunContext.model_validate(self.require_run_context_value(DIFY_RUN_CONTEXT_KEY)) dataset_ids = node_data.dataset_ids query = variables.get("query") attachments = variables.get("attachments") @@ -254,7 +257,13 @@ class KnowledgeRetrievalNode(LLMUsageTrackingMixin, Node[KnowledgeRetrievalNodeD metadata_model_config=node_data.metadata_model_config, metadata_filtering_conditions=resolved_metadata_conditions, metadata_filtering_mode=metadata_filtering_mode, - attachment_ids=[attachment.related_id for attachment in attachments] if attachments else None, + attachment_ids=[ + parsed_reference.record_id + for attachment in attachments + if (parsed_reference := parse_file_reference(attachment.reference)) is not None + ] + if attachments + else None, ) ) diff --git a/api/core/workflow/nodes/knowledge_retrieval/retrieval.py b/api/core/workflow/nodes/knowledge_retrieval/retrieval.py index e1311ab962..39e2008a2c 100644 --- a/api/core/workflow/nodes/knowledge_retrieval/retrieval.py +++ b/api/core/workflow/nodes/knowledge_retrieval/retrieval.py @@ -1,10 +1,10 @@ from typing import Any, Literal, Protocol +from graphon.model_runtime.entities import LLMUsage +from graphon.nodes.llm.entities import ModelConfig from pydantic 
import BaseModel, Field from core.rag.data_post_processor.data_post_processor import RerankingModelDict, WeightsDict -from dify_graph.model_runtime.entities import LLMUsage -from dify_graph.nodes.llm.entities import ModelConfig from .entities import MetadataFilteringCondition @@ -54,7 +54,7 @@ class KnowledgeRetrievalRequest(BaseModel): tenant_id: str = Field(description="Tenant unique identifier") user_id: str = Field(description="User unique identifier") app_id: str = Field(description="Application unique identifier") - user_from: str = Field(description="Source of the user request (e.g., 'workflow', 'api')") + user_from: str = Field(description="User identity source for audit logging (e.g., 'account', 'end-user')") dataset_ids: list[str] = Field(description="List of dataset IDs to retrieve from") query: str | None = Field(default=None, description="Query text for knowledge retrieval") retrieval_mode: str = Field(description="Retrieval strategy: 'single' or 'multiple'") diff --git a/api/core/workflow/nodes/trigger_plugin/entities.py b/api/core/workflow/nodes/trigger_plugin/entities.py index ea7d20befe..bf5be2379a 100644 --- a/api/core/workflow/nodes/trigger_plugin/entities.py +++ b/api/core/workflow/nodes/trigger_plugin/entities.py @@ -1,12 +1,12 @@ from collections.abc import Mapping from typing import Any, Literal, Union +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import NodeType from pydantic import BaseModel, Field, ValidationInfo, field_validator from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE from core.trigger.entities.entities import EventParameter -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import NodeType from .exc import TriggerEventParameterError diff --git a/api/core/workflow/nodes/trigger_plugin/trigger_event_node.py b/api/core/workflow/nodes/trigger_plugin/trigger_event_node.py index 118c2f2668..e50de11bb9 100644 --- 
a/api/core/workflow/nodes/trigger_plugin/trigger_event_node.py +++ b/api/core/workflow/nodes/trigger_plugin/trigger_event_node.py @@ -1,12 +1,12 @@ from collections.abc import Mapping from typing import Any +from graphon.enums import NodeExecutionType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from graphon.node_events import NodeRunResult +from graphon.nodes.base.node import Node + from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE -from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID -from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from dify_graph.enums import NodeExecutionType, WorkflowNodeExecutionMetadataKey -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.base.node import Node +from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID from .entities import TriggerEventNodeData @@ -53,13 +53,11 @@ class TriggerEventNode(Node[TriggerEventNodeData]): "plugin_unique_identifier": self.node_data.plugin_unique_identifier, }, } - node_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs) - system_inputs = self.graph_runtime_state.variable_pool.system_variables.to_dict() + node_inputs = dict(self.graph_runtime_state.variable_pool.get_by_prefix(self.id)) + system_inputs = self.graph_runtime_state.variable_pool.get_by_prefix(SYSTEM_VARIABLE_NODE_ID) - # TODO: System variables should be directly accessible, no need for special handling - # Set system variables as node outputs. - for var in system_inputs: - node_inputs[SYSTEM_VARIABLE_NODE_ID + "." 
+ var] = system_inputs[var] + for variable_name, value in system_inputs.items(): + node_inputs[f"{SYSTEM_VARIABLE_NODE_ID}.{variable_name}"] = value outputs = dict(node_inputs) return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, diff --git a/api/core/workflow/nodes/trigger_schedule/entities.py b/api/core/workflow/nodes/trigger_schedule/entities.py index 95a2548678..f14ca893c9 100644 --- a/api/core/workflow/nodes/trigger_schedule/entities.py +++ b/api/core/workflow/nodes/trigger_schedule/entities.py @@ -1,10 +1,10 @@ from typing import Literal, Union +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import NodeType from pydantic import BaseModel, Field from core.trigger.constants import TRIGGER_SCHEDULE_NODE_TYPE -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import NodeType class TriggerScheduleNodeData(BaseNodeData): diff --git a/api/core/workflow/nodes/trigger_schedule/exc.py b/api/core/workflow/nodes/trigger_schedule/exc.py index 336d64d58f..10962c3de4 100644 --- a/api/core/workflow/nodes/trigger_schedule/exc.py +++ b/api/core/workflow/nodes/trigger_schedule/exc.py @@ -1,4 +1,4 @@ -from dify_graph.entities.exc import BaseNodeError +from graphon.entities.exc import BaseNodeError class ScheduleNodeError(BaseNodeError): diff --git a/api/core/workflow/nodes/trigger_schedule/trigger_schedule_node.py b/api/core/workflow/nodes/trigger_schedule/trigger_schedule_node.py index b9580e6ab1..a9753ab387 100644 --- a/api/core/workflow/nodes/trigger_schedule/trigger_schedule_node.py +++ b/api/core/workflow/nodes/trigger_schedule/trigger_schedule_node.py @@ -1,11 +1,11 @@ from collections.abc import Mapping +from graphon.enums import NodeExecutionType, WorkflowNodeExecutionStatus +from graphon.node_events import NodeRunResult +from graphon.nodes.base.node import Node + from core.trigger.constants import TRIGGER_SCHEDULE_NODE_TYPE -from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID -from 
dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from dify_graph.enums import NodeExecutionType -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.base.node import Node +from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID from .entities import TriggerScheduleNodeData @@ -31,13 +31,11 @@ class TriggerScheduleNode(Node[TriggerScheduleNodeData]): } def _run(self) -> NodeRunResult: - node_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs) - system_inputs = self.graph_runtime_state.variable_pool.system_variables.to_dict() + node_inputs = dict(self.graph_runtime_state.variable_pool.get_by_prefix(self.id)) + system_inputs = self.graph_runtime_state.variable_pool.get_by_prefix(SYSTEM_VARIABLE_NODE_ID) - # TODO: System variables should be directly accessible, no need for special handling - # Set system variables as node outputs. - for var in system_inputs: - node_inputs[SYSTEM_VARIABLE_NODE_ID + "." + var] = system_inputs[var] + for variable_name, value in system_inputs.items(): + node_inputs[f"{SYSTEM_VARIABLE_NODE_ID}.{variable_name}"] = value outputs = dict(node_inputs) return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, diff --git a/api/core/workflow/nodes/trigger_webhook/entities.py b/api/core/workflow/nodes/trigger_webhook/entities.py index 242bf5ef6a..4d5ad72154 100644 --- a/api/core/workflow/nodes/trigger_webhook/entities.py +++ b/api/core/workflow/nodes/trigger_webhook/entities.py @@ -1,12 +1,12 @@ from collections.abc import Sequence from enum import StrEnum +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import NodeType +from graphon.variables.types import SegmentType from pydantic import BaseModel, Field, field_validator from core.trigger.constants import TRIGGER_WEBHOOK_NODE_TYPE -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import NodeType -from dify_graph.variables.types import SegmentType 
_WEBHOOK_HEADER_ALLOWED_TYPES = frozenset( { diff --git a/api/core/workflow/nodes/trigger_webhook/exc.py b/api/core/workflow/nodes/trigger_webhook/exc.py index 4d87f2a069..00b0b3baad 100644 --- a/api/core/workflow/nodes/trigger_webhook/exc.py +++ b/api/core/workflow/nodes/trigger_webhook/exc.py @@ -1,4 +1,4 @@ -from dify_graph.entities.exc import BaseNodeError +from graphon.entities.exc import BaseNodeError class WebhookNodeError(BaseNodeError): diff --git a/api/core/workflow/nodes/trigger_webhook/node.py b/api/core/workflow/nodes/trigger_webhook/node.py index 317844cbda..ebaac93934 100644 --- a/api/core/workflow/nodes/trigger_webhook/node.py +++ b/api/core/workflow/nodes/trigger_webhook/node.py @@ -2,16 +2,17 @@ import logging from collections.abc import Mapping from typing import Any +from graphon.enums import NodeExecutionType, WorkflowNodeExecutionStatus +from graphon.file import FileTransferMethod +from graphon.node_events import NodeRunResult +from graphon.nodes.base.node import Node +from graphon.nodes.protocols import FileReferenceFactoryProtocol +from graphon.variables.types import SegmentType +from graphon.variables.variables import FileVariable + from core.trigger.constants import TRIGGER_WEBHOOK_NODE_TYPE -from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID -from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from dify_graph.enums import NodeExecutionType -from dify_graph.file import FileTransferMethod -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.base.node import Node -from dify_graph.variables.types import SegmentType -from dify_graph.variables.variables import FileVariable -from factories import file_factory +from core.workflow.file_reference import resolve_file_record_id +from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID from factories.variable_factory import build_segment_with_type from .entities import ContentType, WebhookData @@ -23,6 +24,13 @@ class 
TriggerWebhookNode(Node[WebhookData]): node_type = TRIGGER_WEBHOOK_NODE_TYPE execution_type = NodeExecutionType.ROOT + _file_reference_factory: FileReferenceFactoryProtocol + + def post_init(self) -> None: + from core.workflow.node_runtime import DifyFileReferenceFactory + + self._file_reference_factory = DifyFileReferenceFactory(self.graph_init_params.run_context) + @classmethod def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]: return { @@ -53,16 +61,14 @@ class TriggerWebhookNode(Node[WebhookData]): happens in the trigger controller. """ # Get webhook data from variable pool (injected by Celery task) - webhook_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs) + webhook_inputs = dict(self.graph_runtime_state.variable_pool.get_by_prefix(self.id)) # Extract webhook-specific outputs based on node configuration outputs = self._extract_configured_outputs(webhook_inputs) - system_inputs = self.graph_runtime_state.variable_pool.system_variables.to_dict() + system_inputs = self.graph_runtime_state.variable_pool.get_by_prefix(SYSTEM_VARIABLE_NODE_ID) - # TODO: System variables should be directly accessible, no need for special handling - # Set system variables as node outputs. - for var in system_inputs: - outputs[SYSTEM_VARIABLE_NODE_ID + "." 
+ var] = system_inputs[var] + for variable_name, value in system_inputs.items(): + outputs[f"{SYSTEM_VARIABLE_NODE_ID}.{variable_name}"] = value return NodeRunResult( status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=webhook_inputs, @@ -70,24 +76,20 @@ class TriggerWebhookNode(Node[WebhookData]): ) def generate_file_var(self, param_name: str, file: dict): - dify_ctx = self.require_dify_context() - related_id = file.get("related_id") + file_id = resolve_file_record_id(file.get("reference") or file.get("related_id")) transfer_method_value = file.get("transfer_method") if transfer_method_value: transfer_method = FileTransferMethod.value_of(transfer_method_value) match transfer_method: case FileTransferMethod.LOCAL_FILE | FileTransferMethod.REMOTE_URL: - file["upload_file_id"] = related_id + file["upload_file_id"] = file_id case FileTransferMethod.TOOL_FILE: - file["tool_file_id"] = related_id + file["tool_file_id"] = file_id case FileTransferMethod.DATASOURCE_FILE: - file["datasource_file_id"] = related_id + file["datasource_file_id"] = file_id try: - file_obj = file_factory.build_from_mapping( - mapping=file, - tenant_id=dify_ctx.tenant_id, - ) + file_obj = self._file_reference_factory.build_from_mapping(mapping=file) file_segment = build_segment_with_type(SegmentType.FILE, file_obj) return FileVariable(name=param_name, value=file_segment.value, selector=[self.id, param_name]) except ValueError: diff --git a/api/core/workflow/system_variables.py b/api/core/workflow/system_variables.py new file mode 100644 index 0000000000..9d15a3fcea --- /dev/null +++ b/api/core/workflow/system_variables.py @@ -0,0 +1,236 @@ +from __future__ import annotations + +from collections import defaultdict +from collections.abc import Mapping, Sequence +from enum import StrEnum +from typing import Any, Protocol, cast +from uuid import uuid4 + +from graphon.enums import BuiltinNodeTypes +from graphon.variables import build_segment, segment_to_variable +from graphon.variables.segments 
import Segment +from graphon.variables.variables import RAGPipelineVariableInput, Variable + +from .variable_prefixes import ( + CONVERSATION_VARIABLE_NODE_ID, + ENVIRONMENT_VARIABLE_NODE_ID, + RAG_PIPELINE_VARIABLE_NODE_ID, + SYSTEM_VARIABLE_NODE_ID, +) + + +class SystemVariableKey(StrEnum): + QUERY = "query" + FILES = "files" + CONVERSATION_ID = "conversation_id" + USER_ID = "user_id" + DIALOGUE_COUNT = "dialogue_count" + APP_ID = "app_id" + WORKFLOW_ID = "workflow_id" + WORKFLOW_EXECUTION_ID = "workflow_run_id" + TIMESTAMP = "timestamp" + DOCUMENT_ID = "document_id" + ORIGINAL_DOCUMENT_ID = "original_document_id" + BATCH = "batch" + DATASET_ID = "dataset_id" + DATASOURCE_TYPE = "datasource_type" + DATASOURCE_INFO = "datasource_info" + INVOKE_FROM = "invoke_from" + + +class _VariablePoolReader(Protocol): + def get(self, selector: Sequence[str], /) -> Segment | None: ... + + def get_by_prefix(self, prefix: str, /) -> Mapping[str, object]: ... + + +class _VariablePoolWriter(_VariablePoolReader, Protocol): + def add(self, selector: Sequence[str], value: object, /) -> None: ... + + +class _VariableLoader(Protocol): + def load_variables(self, selectors: list[list[str]]) -> Sequence[object]: ... 
+ + +def system_variable_name(key: str | SystemVariableKey) -> str: + return key.value if isinstance(key, SystemVariableKey) else key + + +def system_variable_selector(key: str | SystemVariableKey) -> tuple[str, str]: + return SYSTEM_VARIABLE_NODE_ID, system_variable_name(key) + + +def _normalize_system_variable_values(values: Mapping[str, Any] | None = None, /, **kwargs: Any) -> dict[str, Any]: + raw_values = dict(values or {}) + raw_values.update(kwargs) + + workflow_execution_id = raw_values.pop("workflow_execution_id", None) + if workflow_execution_id is not None and SystemVariableKey.WORKFLOW_EXECUTION_ID.value not in raw_values: + raw_values[SystemVariableKey.WORKFLOW_EXECUTION_ID.value] = workflow_execution_id + + normalized: dict[str, Any] = {} + for key, value in raw_values.items(): + if value is None: + continue + normalized[system_variable_name(key)] = value + + normalized.setdefault(SystemVariableKey.FILES.value, []) + return normalized + + +def build_system_variables(values: Mapping[str, Any] | None = None, /, **kwargs: Any) -> list[Variable]: + normalized = _normalize_system_variable_values(values, **kwargs) + + return [ + cast( + Variable, + segment_to_variable( + segment=build_segment(value), + selector=system_variable_selector(key), + name=key, + ), + ) + for key, value in normalized.items() + ] + + +def default_system_variables() -> list[Variable]: + return build_system_variables(workflow_run_id=str(uuid4())) + + +def system_variables_to_mapping(system_variables: Sequence[Variable]) -> dict[str, Any]: + return {variable.name: variable.value for variable in system_variables} + + +def _with_selector(variable: Variable, node_id: str) -> Variable: + selector = [node_id, variable.name] + if list(variable.selector) == selector: + return variable + return variable.model_copy(update={"selector": selector}) + + +def build_bootstrap_variables( + *, + system_variables: Sequence[Variable] = (), + environment_variables: Sequence[Variable] = (), + 
conversation_variables: Sequence[Variable] = (), + rag_pipeline_variables: Sequence[RAGPipelineVariableInput] = (), +) -> list[Variable]: + variables = [ + *(_with_selector(variable, SYSTEM_VARIABLE_NODE_ID) for variable in system_variables), + *(_with_selector(variable, ENVIRONMENT_VARIABLE_NODE_ID) for variable in environment_variables), + *(_with_selector(variable, CONVERSATION_VARIABLE_NODE_ID) for variable in conversation_variables), + ] + + rag_pipeline_variables_map: defaultdict[str, dict[str, Any]] = defaultdict(dict) + for rag_var in rag_pipeline_variables: + node_id = rag_var.variable.belong_to_node_id + key = rag_var.variable.variable + rag_pipeline_variables_map[node_id][key] = rag_var.value + + for node_id, value in rag_pipeline_variables_map.items(): + variables.append( + cast( + Variable, + segment_to_variable( + segment=build_segment(value), + selector=(RAG_PIPELINE_VARIABLE_NODE_ID, node_id), + name=node_id, + ), + ) + ) + + return variables + + +def get_system_segment(variable_pool: _VariablePoolReader, key: str | SystemVariableKey) -> Segment | None: + return variable_pool.get(system_variable_selector(key)) + + +def get_system_value(variable_pool: _VariablePoolReader, key: str | SystemVariableKey) -> Any: + segment = get_system_segment(variable_pool, key) + return None if segment is None else segment.value + + +def get_system_text(variable_pool: _VariablePoolReader, key: str | SystemVariableKey) -> str | None: + segment = get_system_segment(variable_pool, key) + if segment is None: + return None + text = getattr(segment, "text", None) + return text if isinstance(text, str) else None + + +def get_all_system_variables(variable_pool: _VariablePoolReader) -> Mapping[str, object]: + return variable_pool.get_by_prefix(SYSTEM_VARIABLE_NODE_ID) + + +_MEMORY_BOOTSTRAP_NODE_TYPES = frozenset( + ( + BuiltinNodeTypes.LLM, + BuiltinNodeTypes.QUESTION_CLASSIFIER, + BuiltinNodeTypes.PARAMETER_EXTRACTOR, + ) +) + + +def get_node_creation_preload_selectors( + *, 
+ node_type: str, + node_data: object, +) -> tuple[tuple[str, str], ...]: + """Return selectors that must exist before node construction begins.""" + + if node_type not in _MEMORY_BOOTSTRAP_NODE_TYPES or getattr(node_data, "memory", None) is None: + return () + + return (system_variable_selector(SystemVariableKey.CONVERSATION_ID),) + + +def preload_node_creation_variables( + *, + variable_loader: _VariableLoader, + variable_pool: _VariablePoolWriter, + selectors: Sequence[Sequence[str]], +) -> None: + """Load constructor-time variables before node or graph creation.""" + + seen_selectors: set[tuple[str, ...]] = set() + selectors_to_load: list[list[str]] = [] + for selector in selectors: + normalized_selector = tuple(selector) + if len(normalized_selector) < 2: + raise ValueError(f"Invalid preload selector: {selector}") + if normalized_selector in seen_selectors: + continue + seen_selectors.add(normalized_selector) + if variable_pool.get(normalized_selector) is None: + selectors_to_load.append(list(normalized_selector)) + + loaded_variables = variable_loader.load_variables(selectors_to_load) + for variable in loaded_variables: + raw_selector = getattr(variable, "selector", ()) + loaded_selector = list(raw_selector) + if len(loaded_selector) < 2: + raise ValueError(f"Invalid loaded variable selector: {raw_selector}") + variable_pool.add(loaded_selector[:2], variable) + + +def inject_default_system_variable_mappings( + *, + node_id: str, + node_type: str, + node_data: object, + variable_mapping: Mapping[str, Sequence[str]], +) -> Mapping[str, Sequence[str]]: + """Add workflow-owned implicit sys mappings that `graphon` should not know about.""" + + if node_type != BuiltinNodeTypes.LLM or getattr(node_data, "memory", None) is None: + return variable_mapping + + query_mapping_key = f"{node_id}.#sys.query#" + if query_mapping_key in variable_mapping: + return variable_mapping + + augmented_mapping = dict(variable_mapping) + augmented_mapping[query_mapping_key] = 
system_variable_selector(SystemVariableKey.QUERY) + return augmented_mapping diff --git a/api/core/workflow/template_rendering.py b/api/core/workflow/template_rendering.py new file mode 100644 index 0000000000..d51cfadd09 --- /dev/null +++ b/api/core/workflow/template_rendering.py @@ -0,0 +1,30 @@ +from __future__ import annotations + +from collections.abc import Mapping +from typing import Any + +from graphon.nodes.code.entities import CodeLanguage +from graphon.template_rendering import Jinja2TemplateRenderer, TemplateRenderError + +from core.helper.code_executor.code_executor import CodeExecutionError, CodeExecutor + + +class CodeExecutorJinja2TemplateRenderer(Jinja2TemplateRenderer): + """Sandbox-backed Jinja2 renderer for workflow-owned node composition.""" + + def render_template(self, template: str, variables: Mapping[str, Any]) -> str: + try: + result = CodeExecutor.execute_workflow_code_template( + language=CodeLanguage.JINJA2, + code=template, + inputs=variables, + ) + except Exception as exc: + if isinstance(exc, CodeExecutionError): + raise TemplateRenderError(str(exc)) from exc + raise + + rendered = result.get("result") + if not isinstance(rendered, str): + raise TemplateRenderError("Template render result must be a string.") + return rendered diff --git a/api/core/workflow/variable_pool_initializer.py b/api/core/workflow/variable_pool_initializer.py new file mode 100644 index 0000000000..43523e01b2 --- /dev/null +++ b/api/core/workflow/variable_pool_initializer.py @@ -0,0 +1,15 @@ +from collections.abc import Mapping, Sequence +from typing import Any + +from graphon.runtime import VariablePool +from graphon.variables.variables import Variable + + +def add_variables_to_pool(variable_pool: VariablePool, variables: Sequence[Variable]) -> None: + for variable in variables: + variable_pool.add(variable.selector, variable) + + +def add_node_inputs_to_pool(variable_pool: VariablePool, *, node_id: str, inputs: Mapping[str, Any]) -> None: + for key, value in 
inputs.items(): + variable_pool.add((node_id, key), value) diff --git a/api/dify_graph/constants.py b/api/core/workflow/variable_prefixes.py similarity index 100% rename from api/dify_graph/constants.py rename to api/core/workflow/variable_prefixes.py diff --git a/api/core/workflow/workflow_entry.py b/api/core/workflow/workflow_entry.py index 2e51a06bab..2346a95d6a 100644 --- a/api/core/workflow/workflow_entry.py +++ b/api/core/workflow/workflow_entry.py @@ -1,36 +1,44 @@ import logging import time from collections.abc import Generator, Mapping, Sequence -from typing import Any, cast +from typing import Any + +from graphon.entities import GraphInitParams +from graphon.entities.graph_config import NodeConfigDictAdapter +from graphon.errors import WorkflowNodeRunFailedError +from graphon.file import File +from graphon.graph import Graph +from graphon.graph_engine import GraphEngine, GraphEngineConfig +from graphon.graph_engine.command_channels import CommandChannel, InMemoryChannel +from graphon.graph_engine.layers import DebugLoggingLayer, ExecutionLimitsLayer +from graphon.graph_events import GraphEngineEvent, GraphNodeEventBase, GraphRunFailedEvent +from graphon.nodes import BuiltinNodeTypes +from graphon.nodes.base.node import Node +from graphon.runtime import ChildGraphNotFoundError, GraphRuntimeState, VariablePool +from graphon.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool from configs import dify_config +from context import capture_current_context from core.app.apps.exc import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context +from core.app.file_access import DatabaseFileAccessController from core.app.workflow.layers.llm_quota import LLMQuotaLayer from core.app.workflow.layers.observability import ObservabilityLayer -from core.workflow.node_factory import DifyNodeFactory, resolve_workflow_node_class -from dify_graph.constants import 
ENVIRONMENT_VARIABLE_NODE_ID -from dify_graph.entities import GraphInitParams -from dify_graph.entities.graph_config import NodeConfigDictAdapter -from dify_graph.errors import WorkflowNodeRunFailedError -from dify_graph.file.models import File -from dify_graph.graph import Graph -from dify_graph.graph_engine import GraphEngine, GraphEngineConfig -from dify_graph.graph_engine.command_channels import InMemoryChannel -from dify_graph.graph_engine.layers import DebugLoggingLayer, ExecutionLimitsLayer -from dify_graph.graph_engine.layers.base import GraphEngineLayer -from dify_graph.graph_engine.protocols.command_channel import CommandChannel -from dify_graph.graph_events import GraphEngineEvent, GraphNodeEventBase, GraphRunFailedEvent -from dify_graph.nodes import BuiltinNodeTypes -from dify_graph.nodes.base.node import Node -from dify_graph.runtime import ChildGraphNotFoundError, GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable -from dify_graph.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader, load_into_variable_pool +from core.workflow.node_factory import DifyNodeFactory, is_start_node_type, resolve_workflow_node_class +from core.workflow.system_variables import ( + default_system_variables, + get_node_creation_preload_selectors, + inject_default_system_variable_mappings, + preload_node_creation_variables, +) +from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool +from core.workflow.variable_prefixes import ENVIRONMENT_VARIABLE_NODE_ID from extensions.otel.runtime import is_instrument_flag_enabled from factories import file_factory from models.workflow import Workflow logger = logging.getLogger(__name__) +_file_access_controller = DatabaseFileAccessController() class _WorkflowChildEngineBuilder: @@ -59,16 +67,22 @@ class _WorkflowChildEngineBuilder: *, workflow_id: str, graph_init_params: GraphInitParams, - graph_runtime_state: GraphRuntimeState, - graph_config: 
Mapping[str, Any], + parent_graph_runtime_state: GraphRuntimeState, root_node_id: str, - layers: Sequence[object] = (), + variable_pool: VariablePool | None = None, ) -> GraphEngine: + """Build a child engine with a fresh runtime state and only child-safe layers.""" + child_graph_runtime_state = GraphRuntimeState( + variable_pool=variable_pool if variable_pool is not None else parent_graph_runtime_state.variable_pool, + start_at=time.perf_counter(), + execution_context=parent_graph_runtime_state.execution_context, + ) node_factory = DifyNodeFactory( graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, + graph_runtime_state=child_graph_runtime_state, ) + graph_config = graph_init_params.graph_config has_root_node = self._has_node_id(graph_config=graph_config, node_id=root_node_id) if has_root_node is False: raise ChildGraphNotFoundError(f"child graph root node '{root_node_id}' not found") @@ -79,17 +93,17 @@ class _WorkflowChildEngineBuilder: root_node_id=root_node_id, ) + command_channel = InMemoryChannel() + config = GraphEngineConfig() child_engine = GraphEngine( workflow_id=workflow_id, graph=child_graph, - graph_runtime_state=graph_runtime_state, - command_channel=InMemoryChannel(), - config=GraphEngineConfig(), + graph_runtime_state=child_graph_runtime_state, + command_channel=command_channel, + config=config, child_engine_builder=self, ) child_engine.layer(LLMQuotaLayer()) - for layer in layers: - child_engine.layer(cast(GraphEngineLayer, layer)) return child_engine @@ -136,6 +150,8 @@ class WorkflowEntry: command_channel = InMemoryChannel() self.command_channel = command_channel + execution_context = capture_current_context() + graph_runtime_state.execution_context = execution_context self._child_engine_builder = _WorkflowChildEngineBuilder() self.graph_engine = GraphEngine( workflow_id=workflow_id, @@ -212,6 +228,8 @@ class WorkflowEntry: # Get node type node_type = node_config_data.type + node_version = 
str(node_config_data.version) + node_cls = resolve_workflow_node_class(node_type=node_type, node_version=node_version) # init graph init params and runtime state graph_init_params = GraphInitParams( @@ -226,15 +244,23 @@ class WorkflowEntry: ), call_depth=0, ) - graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - - # init workflow run state - node_factory = DifyNodeFactory( - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, + graph_runtime_state = GraphRuntimeState( + variable_pool=variable_pool, + start_at=time.perf_counter(), + execution_context=capture_current_context(), + ) + + if is_start_node_type(node_type): + add_node_inputs_to_pool(variable_pool, node_id=node_id, inputs=user_inputs) + + preload_node_creation_variables( + variable_loader=variable_loader, + variable_pool=variable_pool, + selectors=get_node_creation_preload_selectors( + node_type=node_type, + node_data=node_config_data, + ), ) - node = node_factory.create_node(node_config) - node_cls = type(node) try: # variable selector to variable mapping @@ -243,6 +269,12 @@ class WorkflowEntry: ) except NotImplementedError: variable_mapping = {} + variable_mapping = inject_default_system_variable_mappings( + node_id=node_id, + node_type=node_type, + node_data=node_config_data, + variable_mapping=variable_mapping, + ) # Loading missing variable from draft var here, and set it into # variable_pool. 
@@ -260,6 +292,13 @@ class WorkflowEntry: tenant_id=workflow.tenant_id, ) + # init workflow run state + node_factory = DifyNodeFactory( + graph_init_params=graph_init_params, + graph_runtime_state=graph_runtime_state, + ) + node = node_factory.create_node(node_config) + try: generator = cls._traced_node_run(node) except Exception as e: @@ -347,11 +386,8 @@ class WorkflowEntry: raise ValueError(f"Node class not found for node type {node_type}") # init variable pool - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={}, - environment_variables=[], - ) + variable_pool = VariablePool() + add_variables_to_pool(variable_pool, default_system_variables()) # init graph init params and runtime state graph_init_params = GraphInitParams( @@ -366,7 +402,11 @@ class WorkflowEntry: ), call_depth=0, ) - graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) + graph_runtime_state = GraphRuntimeState( + variable_pool=variable_pool, + start_at=time.perf_counter(), + execution_context=capture_current_context(), + ) # init workflow run state node_config = NodeConfigDictAdapter.validate_python({"id": node_id, "data": node_data}) @@ -384,6 +424,12 @@ class WorkflowEntry: ) except NotImplementedError: variable_mapping = {} + variable_mapping = inject_default_system_variable_mappings( + node_id=node_id, + node_type=node_type, + node_data=node_data, + variable_mapping=variable_mapping, + ) cls.mapping_user_inputs_to_variable_pool( variable_mapping=variable_mapping, @@ -477,13 +523,21 @@ class WorkflowEntry: continue if isinstance(input_value, dict) and "type" in input_value and "transfer_method" in input_value: - input_value = file_factory.build_from_mapping(mapping=input_value, tenant_id=tenant_id) + input_value = file_factory.build_from_mapping( + mapping=input_value, + tenant_id=tenant_id, + access_controller=_file_access_controller, + ) if ( isinstance(input_value, list) and all(isinstance(item, dict) 
for item in input_value) and all("type" in item and "transfer_method" in item for item in input_value) ): - input_value = file_factory.build_from_mappings(mappings=input_value, tenant_id=tenant_id) + input_value = file_factory.build_from_mappings( + mappings=input_value, + tenant_id=tenant_id, + access_controller=_file_access_controller, + ) # append variable and value to variable pool if variable_node_id != ENVIRONMENT_VARIABLE_NODE_ID: diff --git a/api/core/workflow/workflow_run_outputs.py b/api/core/workflow/workflow_run_outputs.py new file mode 100644 index 0000000000..bd89f7c441 --- /dev/null +++ b/api/core/workflow/workflow_run_outputs.py @@ -0,0 +1,18 @@ +from collections.abc import Mapping +from typing import Any + +from graphon.enums import BuiltinNodeTypes, NodeType + + +def project_node_outputs_for_workflow_run( + *, + node_type: NodeType, + inputs: Mapping[str, Any], + outputs: Mapping[str, Any], +) -> dict[str, Any]: + """Project internal node outputs onto the workflow-run public contract.""" + + if node_type == BuiltinNodeTypes.START: + return dict(inputs) + + return dict(outputs) diff --git a/api/dify_graph/README.md b/api/dify_graph/README.md deleted file mode 100644 index 2fc5b8b890..0000000000 --- a/api/dify_graph/README.md +++ /dev/null @@ -1,135 +0,0 @@ -# Workflow - -## Project Overview - -This is the workflow graph engine module of Dify, implementing a queue-based distributed workflow execution system. The engine handles agentic AI workflows with support for parallel execution, node iteration, conditional logic, and external command control. - -## Architecture - -### Core Components - -The graph engine follows a layered architecture with strict dependency rules: - -1. 
**Graph Engine** (`graph_engine/`) - Orchestrates workflow execution - - - **Manager** - External control interface for stop/pause/resume commands - - **Worker** - Node execution runtime - - **Command Processing** - Handles control commands (abort, pause, resume) - - **Event Management** - Event propagation and layer notifications - - **Graph Traversal** - Edge processing and skip propagation - - **Response Coordinator** - Path tracking and session management - - **Layers** - Pluggable middleware (debug logging, execution limits) - - **Command Channels** - Communication channels (InMemory, Redis) - -1. **Graph** (`graph/`) - Graph structure and runtime state - - - **Graph Template** - Workflow definition - - **Edge** - Node connections with conditions - - **Runtime State Protocol** - State management interface - -1. **Nodes** (`nodes/`) - Node implementations - - - **Base** - Abstract node classes and variable parsing - - **Specific Nodes** - LLM, Agent, Code, HTTP Request, Iteration, Loop, etc. - -1. **Events** (`node_events/`) - Event system - - - **Base** - Event protocols - - **Node Events** - Node lifecycle events - -1. **Entities** (`entities/`) - Domain models - - - **Variable Pool** - Variable storage - - **Graph Init Params** - Initialization configuration - -## Key Design Patterns - -### Command Channel Pattern - -External workflow control via Redis or in-memory channels: - -```python -# Send stop command to running workflow -channel = RedisChannel(redis_client, f"workflow:{task_id}:commands") -channel.send_command(AbortCommand(reason="User requested")) -``` - -### Layer System - -Extensible middleware for cross-cutting concerns: - -```python -engine = GraphEngine(graph) -engine.layer(DebugLoggingLayer(level="INFO")) -engine.layer(ExecutionLimitsLayer(max_nodes=100)) -``` - -`engine.layer()` binds the read-only runtime state before execution, so layer hooks -can assume `graph_runtime_state` is available. 
- -### Event-Driven Architecture - -All node executions emit events for monitoring and integration: - -- `NodeRunStartedEvent` - Node execution begins -- `NodeRunSucceededEvent` - Node completes successfully -- `NodeRunFailedEvent` - Node encounters error -- `GraphRunStartedEvent/GraphRunCompletedEvent` - Workflow lifecycle - -### Variable Pool - -Centralized variable storage with namespace isolation: - -```python -# Variables scoped by node_id -pool.add(["node1", "output"], value) -result = pool.get(["node1", "output"]) -``` - -## Import Architecture Rules - -The codebase enforces strict layering via import-linter: - -1. **Workflow Layers** (top to bottom): - - - graph_engine → graph_events → graph → nodes → node_events → entities - -1. **Graph Engine Internal Layers**: - - - orchestration → command_processing → event_management → graph_traversal → domain - -1. **Domain Isolation**: - - - Domain models cannot import from infrastructure layers - -1. **Command Channel Independence**: - - - InMemory and Redis channels must remain independent - -## Common Tasks - -### Adding a New Node Type - -1. Create node class in `nodes//` -1. Inherit from `BaseNode` or appropriate base class -1. Implement `_run()` method -1. Ensure the node module is importable under `nodes//` -1. Add tests in `tests/unit_tests/dify_graph/nodes/` - -### Implementing a Custom Layer - -1. Create class inheriting from `Layer` base -1. Override lifecycle methods: `on_graph_start()`, `on_event()`, `on_graph_end()` -1. 
Add to engine via `engine.layer()` - -### Debugging Workflow Execution - -Enable debug logging layer: - -```python -debug_layer = DebugLoggingLayer( - level="DEBUG", - include_inputs=True, - include_outputs=True -) -``` diff --git a/api/dify_graph/context/__init__.py b/api/dify_graph/context/__init__.py deleted file mode 100644 index 103f526bec..0000000000 --- a/api/dify_graph/context/__init__.py +++ /dev/null @@ -1,34 +0,0 @@ -""" -Execution Context - Context management for workflow execution. - -This package provides Flask-independent context management for workflow -execution in multi-threaded environments. -""" - -from dify_graph.context.execution_context import ( - AppContext, - ContextProviderNotFoundError, - ExecutionContext, - IExecutionContext, - NullAppContext, - capture_current_context, - read_context, - register_context, - register_context_capturer, - reset_context_provider, -) -from dify_graph.context.models import SandboxContext - -__all__ = [ - "AppContext", - "ContextProviderNotFoundError", - "ExecutionContext", - "IExecutionContext", - "NullAppContext", - "SandboxContext", - "capture_current_context", - "read_context", - "register_context", - "register_context_capturer", - "reset_context_provider", -] diff --git a/api/dify_graph/conversation_variable_updater.py b/api/dify_graph/conversation_variable_updater.py deleted file mode 100644 index 17b19f2502..0000000000 --- a/api/dify_graph/conversation_variable_updater.py +++ /dev/null @@ -1,39 +0,0 @@ -import abc -from typing import Protocol - -from dify_graph.variables import VariableBase - - -class ConversationVariableUpdater(Protocol): - """ - ConversationVariableUpdater defines an abstraction for updating conversation variable values. - - It is intended for use by `v1.VariableAssignerNode` and `v2.VariableAssignerNode` when updating - conversation variables. - - Implementations may choose to batch updates. 
If batching is used, the `flush` method - should be implemented to persist buffered changes, and `update` - should handle buffering accordingly. - - Note: Since implementations may buffer updates, instances of ConversationVariableUpdater - are not thread-safe. Each VariableAssignerNode should create its own instance during execution. - """ - - @abc.abstractmethod - def update(self, conversation_id: str, variable: "VariableBase"): - """ - Updates the value of the specified conversation variable in the underlying storage. - - :param conversation_id: The ID of the conversation to update. Typically references `ConversationVariable.id`. - :param variable: The `VariableBase` instance containing the updated value. - """ - pass - - @abc.abstractmethod - def flush(self): - """ - Flushes all pending updates to the underlying storage system. - - If the implementation does not buffer updates, this method can be a no-op. - """ - pass diff --git a/api/dify_graph/entities/__init__.py b/api/dify_graph/entities/__init__.py deleted file mode 100644 index ef7789c49c..0000000000 --- a/api/dify_graph/entities/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -from .graph_init_params import GraphInitParams -from .workflow_execution import WorkflowExecution -from .workflow_node_execution import WorkflowNodeExecution -from .workflow_start_reason import WorkflowStartReason - -__all__ = [ - "GraphInitParams", - "WorkflowExecution", - "WorkflowNodeExecution", - "WorkflowStartReason", -] diff --git a/api/dify_graph/entities/base_node_data.py b/api/dify_graph/entities/base_node_data.py deleted file mode 100644 index 47b37c9daf..0000000000 --- a/api/dify_graph/entities/base_node_data.py +++ /dev/null @@ -1,178 +0,0 @@ -from __future__ import annotations - -import json -from abc import ABC -from builtins import type as type_ -from enum import StrEnum -from typing import Any, Union - -from pydantic import BaseModel, ConfigDict, Field, model_validator - -from dify_graph.entities.exc import 
DefaultValueTypeError -from dify_graph.enums import ErrorStrategy, NodeType - -# Project supports Python 3.11+, where `typing.Union[...]` is valid in `isinstance`. -_NumberType = Union[int, float] - - -class RetryConfig(BaseModel): - """node retry config""" - - max_retries: int = 0 # max retry times - retry_interval: int = 0 # retry interval in milliseconds - retry_enabled: bool = False # whether retry is enabled - - @property - def retry_interval_seconds(self) -> float: - return self.retry_interval / 1000 - - -class DefaultValueType(StrEnum): - STRING = "string" - NUMBER = "number" - OBJECT = "object" - ARRAY_NUMBER = "array[number]" - ARRAY_STRING = "array[string]" - ARRAY_OBJECT = "array[object]" - ARRAY_FILES = "array[file]" - - -class DefaultValue(BaseModel): - value: Any = None - type: DefaultValueType - key: str - - @staticmethod - def _parse_json(value: str): - """Unified JSON parsing handler""" - try: - return json.loads(value) - except json.JSONDecodeError: - raise DefaultValueTypeError(f"Invalid JSON format for value: {value}") - - @staticmethod - def _validate_array(value: Any, element_type: type_ | tuple[type_, ...]) -> bool: - """Unified array type validation""" - return isinstance(value, list) and all(isinstance(x, element_type) for x in value) - - @staticmethod - def _convert_number(value: str) -> float: - """Unified number conversion handler""" - try: - return float(value) - except ValueError: - raise DefaultValueTypeError(f"Cannot convert to number: {value}") - - @model_validator(mode="after") - def validate_value_type(self) -> DefaultValue: - # Type validation configuration - type_validators: dict[DefaultValueType, dict[str, Any]] = { - DefaultValueType.STRING: { - "type": str, - "converter": lambda x: x, - }, - DefaultValueType.NUMBER: { - "type": _NumberType, - "converter": self._convert_number, - }, - DefaultValueType.OBJECT: { - "type": dict, - "converter": self._parse_json, - }, - DefaultValueType.ARRAY_NUMBER: { - "type": list, - 
"element_type": _NumberType, - "converter": self._parse_json, - }, - DefaultValueType.ARRAY_STRING: { - "type": list, - "element_type": str, - "converter": self._parse_json, - }, - DefaultValueType.ARRAY_OBJECT: { - "type": list, - "element_type": dict, - "converter": self._parse_json, - }, - } - - validator: dict[str, Any] = type_validators.get(self.type, {}) - if not validator: - if self.type == DefaultValueType.ARRAY_FILES: - # Handle files type - return self - raise DefaultValueTypeError(f"Unsupported type: {self.type}") - - # Handle string input cases - if isinstance(self.value, str) and self.type != DefaultValueType.STRING: - self.value = validator["converter"](self.value) - - # Validate base type - if not isinstance(self.value, validator["type"]): - raise DefaultValueTypeError(f"Value must be {validator['type'].__name__} type for {self.value}") - - # Validate array element types - if validator["type"] == list and not self._validate_array(self.value, validator["element_type"]): - raise DefaultValueTypeError(f"All elements must be {validator['element_type'].__name__} for {self.value}") - - return self - - -class BaseNodeData(ABC, BaseModel): - # Raw graph payloads are first validated through `NodeConfigDictAdapter`, where - # `node["data"]` is typed as `BaseNodeData` before the concrete node class is known. - # `type` therefore accepts downstream string node kinds; unknown node implementations - # are rejected later when the node factory resolves the node registry. - # At that boundary, node-specific fields are still "extra" relative to this shared DTO, - # and persisted templates/workflows also carry undeclared compatibility keys such as - # `selected`, `params`, `paramSchemas`, and `datasource_label`. Keep extras permissive - # here until graph parsing becomes discriminated by node type or those legacy payloads - # are normalized. 
- model_config = ConfigDict(extra="allow") - - type: NodeType - title: str = "" - desc: str | None = None - version: str = "1" - error_strategy: ErrorStrategy | None = None - default_value: list[DefaultValue] | None = None - retry_config: RetryConfig = Field(default_factory=RetryConfig) - - @property - def default_value_dict(self) -> dict[str, Any]: - if self.default_value: - return {item.key: item.value for item in self.default_value} - return {} - - def __getitem__(self, key: str) -> Any: - """ - Dict-style access without calling model_dump() on every lookup. - Prefer using model fields and Pydantic's extra storage. - """ - # First, check declared model fields - if key in self.__class__.model_fields: - return getattr(self, key) - - # Then, check undeclared compatibility fields stored in Pydantic's extra dict. - extras = getattr(self, "__pydantic_extra__", None) - if extras is None: - extras = getattr(self, "model_extra", None) - if extras is not None and key in extras: - return extras[key] - - raise KeyError(key) - - def get(self, key: str, default: Any = None) -> Any: - """ - Dict-style .get() without calling model_dump() on every lookup. 
- """ - if key in self.__class__.model_fields: - return getattr(self, key) - - extras = getattr(self, "__pydantic_extra__", None) - if extras is None: - extras = getattr(self, "model_extra", None) - if extras is not None and key in extras: - return extras.get(key, default) - - return default diff --git a/api/dify_graph/entities/exc.py b/api/dify_graph/entities/exc.py deleted file mode 100644 index aeecf40640..0000000000 --- a/api/dify_graph/entities/exc.py +++ /dev/null @@ -1,10 +0,0 @@ -class BaseNodeError(ValueError): - """Base class for node errors.""" - - pass - - -class DefaultValueTypeError(BaseNodeError): - """Raised when the default value type is invalid.""" - - pass diff --git a/api/dify_graph/entities/graph_config.py b/api/dify_graph/entities/graph_config.py deleted file mode 100644 index 36f7b94e82..0000000000 --- a/api/dify_graph/entities/graph_config.py +++ /dev/null @@ -1,23 +0,0 @@ -from __future__ import annotations - -import sys - -from pydantic import TypeAdapter, with_config - -from dify_graph.entities.base_node_data import BaseNodeData - -if sys.version_info >= (3, 12): - from typing import TypedDict -else: - from typing_extensions import TypedDict - - -@with_config(extra="allow") -class NodeConfigDict(TypedDict): - id: str - # This is the permissive raw graph boundary. Node factories re-validate `data` - # with the concrete `NodeData` subtype after resolving the node implementation. 
- data: BaseNodeData - - -NodeConfigDictAdapter = TypeAdapter(NodeConfigDict) diff --git a/api/dify_graph/entities/graph_init_params.py b/api/dify_graph/entities/graph_init_params.py deleted file mode 100644 index f785d58a52..0000000000 --- a/api/dify_graph/entities/graph_init_params.py +++ /dev/null @@ -1,24 +0,0 @@ -from collections.abc import Mapping -from typing import Any - -from pydantic import BaseModel, Field - -DIFY_RUN_CONTEXT_KEY = "_dify" - - -class GraphInitParams(BaseModel): - """GraphInitParams encapsulates the configurations and contextual information - that remain constant throughout a single execution of the graph engine. - - A single execution is defined as follows: as long as the execution has not reached - its conclusion, it is considered one execution. For instance, if a workflow is suspended - and later resumed, it is still regarded as a single execution, not two. - - For the state diagram of workflow execution, refer to `WorkflowExecutionStatus`. - """ - - # init params - workflow_id: str = Field(..., description="workflow id") - graph_config: Mapping[str, Any] = Field(..., description="graph config") - run_context: Mapping[str, Any] = Field(..., description="runtime context") - call_depth: int = Field(..., description="call depth") diff --git a/api/dify_graph/entities/pause_reason.py b/api/dify_graph/entities/pause_reason.py deleted file mode 100644 index 86d8c8ca16..0000000000 --- a/api/dify_graph/entities/pause_reason.py +++ /dev/null @@ -1,50 +0,0 @@ -from collections.abc import Mapping -from enum import StrEnum, auto -from typing import Annotated, Any, Literal, TypeAlias - -from pydantic import BaseModel, Field - -from dify_graph.nodes.human_input.entities import FormInput, UserAction - - -class PauseReasonType(StrEnum): - HUMAN_INPUT_REQUIRED = auto() - SCHEDULED_PAUSE = auto() - - -class HumanInputRequired(BaseModel): - TYPE: Literal[PauseReasonType.HUMAN_INPUT_REQUIRED] = PauseReasonType.HUMAN_INPUT_REQUIRED - form_id: str - 
form_content: str - inputs: list[FormInput] = Field(default_factory=list) - actions: list[UserAction] = Field(default_factory=list) - display_in_ui: bool = False - node_id: str - node_title: str - - # The `resolved_default_values` stores the resolved values of variable defaults. It's a mapping from - # `output_variable_name` to their resolved values. - # - # For example, The form contains a input with output variable name `name` and placeholder type `VARIABLE`, its - # selector is ["start", "name"]. While the HumanInputNode is executed, the correspond value of variable - # `start.name` in variable pool is `John`. Thus, the resolved value of the output variable `name` is `John`. The - # `resolved_default_values` is `{"name": "John"}`. - # - # Only form inputs with default value type `VARIABLE` will be resolved and stored in `resolved_default_values`. - resolved_default_values: Mapping[str, Any] = Field(default_factory=dict) - - # The `form_token` is the token used to submit the form via UI surfaces. It corresponds to - # `HumanInputFormRecipient.access_token`. - # - # This field is `None` if webapp delivery is not set and not - # in orchestrating mode. - form_token: str | None = None - - -class SchedulingPause(BaseModel): - TYPE: Literal[PauseReasonType.SCHEDULED_PAUSE] = PauseReasonType.SCHEDULED_PAUSE - - message: str - - -PauseReason: TypeAlias = Annotated[HumanInputRequired | SchedulingPause, Field(discriminator="TYPE")] diff --git a/api/dify_graph/entities/workflow_execution.py b/api/dify_graph/entities/workflow_execution.py deleted file mode 100644 index 459ac46415..0000000000 --- a/api/dify_graph/entities/workflow_execution.py +++ /dev/null @@ -1,74 +0,0 @@ -""" -Domain entities for workflow execution. - -Models are independent of the storage mechanism and don't contain -implementation details like tenant_id, app_id, etc. 
-""" - -from __future__ import annotations - -from collections.abc import Mapping -from datetime import datetime -from typing import Any - -from pydantic import BaseModel, Field - -from dify_graph.enums import WorkflowExecutionStatus, WorkflowType -from libs.datetime_utils import naive_utc_now - - -class WorkflowExecution(BaseModel): - """ - Domain model for workflow execution based on WorkflowRun but without - user, tenant, and app attributes. - """ - - id_: str = Field(...) - workflow_id: str = Field(...) - workflow_version: str = Field(...) - workflow_type: WorkflowType = Field(...) - graph: Mapping[str, Any] = Field(...) - - inputs: Mapping[str, Any] = Field(...) - outputs: Mapping[str, Any] | None = None - - status: WorkflowExecutionStatus = WorkflowExecutionStatus.RUNNING - error_message: str = Field(default="") - total_tokens: int = Field(default=0) - total_steps: int = Field(default=0) - exceptions_count: int = Field(default=0) - - started_at: datetime = Field(...) - finished_at: datetime | None = None - - @property - def elapsed_time(self) -> float: - """ - Calculate elapsed time in seconds. - If workflow is not finished, use current time. 
- """ - end_time = self.finished_at or naive_utc_now() - return (end_time - self.started_at).total_seconds() - - @classmethod - def new( - cls, - *, - id_: str, - workflow_id: str, - workflow_type: WorkflowType, - workflow_version: str, - graph: Mapping[str, Any], - inputs: Mapping[str, Any], - started_at: datetime, - ) -> WorkflowExecution: - return WorkflowExecution( - id_=id_, - workflow_id=workflow_id, - workflow_type=workflow_type, - workflow_version=workflow_version, - graph=graph, - inputs=inputs, - status=WorkflowExecutionStatus.RUNNING, - started_at=started_at, - ) diff --git a/api/dify_graph/entities/workflow_node_execution.py b/api/dify_graph/entities/workflow_node_execution.py deleted file mode 100644 index bc7e0d02e5..0000000000 --- a/api/dify_graph/entities/workflow_node_execution.py +++ /dev/null @@ -1,147 +0,0 @@ -""" -Domain entities for workflow node execution. - -This module contains the domain model for workflow node execution, which is used -by the core workflow module. These models are independent of the storage mechanism -and don't contain implementation details like tenant_id, app_id, etc. -""" - -from collections.abc import Mapping -from datetime import datetime -from typing import Any - -from pydantic import BaseModel, Field, PrivateAttr - -from dify_graph.enums import NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus - - -class WorkflowNodeExecution(BaseModel): - """ - Domain model for workflow node execution. - - This model represents the core business entity of a node execution, - without implementation details like tenant_id, app_id, etc. - - Note: User/context-specific fields (triggered_from, created_by, created_by_role) - have been moved to the repository implementation to keep the domain model clean. - These fields are still accepted in the constructor for backward compatibility, - but they are not stored in the model. 
- """ - - # --------- Core identification fields --------- - - # Unique identifier for this execution record, used when persisting to storage. - # Value is a UUID string (e.g., '09b3e04c-f9ae-404c-ad82-290b8d7bd382'). - id: str - - # Optional secondary ID for cross-referencing purposes. - # - # NOTE: For referencing the persisted record, use `id` rather than `node_execution_id`. - # While `node_execution_id` may sometimes be a UUID string, this is not guaranteed. - # In most scenarios, `id` should be used as the primary identifier. - node_execution_id: str | None = None - workflow_id: str # ID of the workflow this node belongs to - workflow_execution_id: str | None = None # ID of the specific workflow run (null for single-step debugging) - # --------- Core identification fields ends --------- - - # Execution positioning and flow - index: int # Sequence number for ordering in trace visualization - predecessor_node_id: str | None = None # ID of the node that executed before this one - node_id: str # ID of the node being executed - node_type: NodeType # Type of node (e.g., start, llm, downstream response node) - title: str # Display title of the node - - # Execution data - # The `inputs` and `outputs` fields hold the full content - inputs: Mapping[str, Any] | None = None # Input variables used by this node - process_data: Mapping[str, Any] | None = None # Intermediate processing data - outputs: Mapping[str, Any] | None = None # Output variables produced by this node - - # Execution state - status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING # Current execution status - error: str | None = None # Error message if execution failed - elapsed_time: float = Field(default=0.0) # Time taken for execution in seconds - - # Additional metadata - metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None # Execution metadata (tokens, cost, etc.) 
- - # Timing information - created_at: datetime # When execution started - finished_at: datetime | None = None # When execution completed - - _truncated_inputs: Mapping[str, Any] | None = PrivateAttr(None) - _truncated_outputs: Mapping[str, Any] | None = PrivateAttr(None) - _truncated_process_data: Mapping[str, Any] | None = PrivateAttr(None) - - def get_truncated_inputs(self) -> Mapping[str, Any] | None: - return self._truncated_inputs - - def get_truncated_outputs(self) -> Mapping[str, Any] | None: - return self._truncated_outputs - - def get_truncated_process_data(self) -> Mapping[str, Any] | None: - return self._truncated_process_data - - def set_truncated_inputs(self, truncated_inputs: Mapping[str, Any] | None): - self._truncated_inputs = truncated_inputs - - def set_truncated_outputs(self, truncated_outputs: Mapping[str, Any] | None): - self._truncated_outputs = truncated_outputs - - def set_truncated_process_data(self, truncated_process_data: Mapping[str, Any] | None): - self._truncated_process_data = truncated_process_data - - def get_response_inputs(self) -> Mapping[str, Any] | None: - inputs = self.get_truncated_inputs() - if inputs: - return inputs - return self.inputs - - @property - def inputs_truncated(self): - return self._truncated_inputs is not None - - @property - def outputs_truncated(self): - return self._truncated_outputs is not None - - @property - def process_data_truncated(self): - return self._truncated_process_data is not None - - def get_response_outputs(self) -> Mapping[str, Any] | None: - outputs = self.get_truncated_outputs() - if outputs is not None: - return outputs - return self.outputs - - def get_response_process_data(self) -> Mapping[str, Any] | None: - process_data = self.get_truncated_process_data() - if process_data is not None: - return process_data - return self.process_data - - def update_from_mapping( - self, - inputs: Mapping[str, Any] | None = None, - process_data: Mapping[str, Any] | None = None, - outputs: Mapping[str, 
Any] | None = None, - metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None, - ): - """ - Update the model from mappings. - - Args: - inputs: The inputs to update - process_data: The process data to update - outputs: The outputs to update - metadata: The metadata to update - """ - if inputs is not None: - self.inputs = dict(inputs) - if process_data is not None: - self.process_data = dict(process_data) - if outputs is not None: - self.outputs = dict(outputs) - if metadata is not None: - self.metadata = dict(metadata) diff --git a/api/dify_graph/entities/workflow_start_reason.py b/api/dify_graph/entities/workflow_start_reason.py deleted file mode 100644 index df0f75383b..0000000000 --- a/api/dify_graph/entities/workflow_start_reason.py +++ /dev/null @@ -1,8 +0,0 @@ -from enum import StrEnum - - -class WorkflowStartReason(StrEnum): - """Reason for workflow start events across graph/queue/SSE layers.""" - - INITIAL = "initial" # First start of a workflow run. - RESUMPTION = "resumption" # Start triggered after resuming a paused run. diff --git a/api/dify_graph/enums.py b/api/dify_graph/enums.py deleted file mode 100644 index cfb135cbb0..0000000000 --- a/api/dify_graph/enums.py +++ /dev/null @@ -1,286 +0,0 @@ -from enum import StrEnum -from typing import ClassVar, TypeAlias - - -class NodeState(StrEnum): - """State of a node or edge during workflow execution.""" - - UNKNOWN = "unknown" - TAKEN = "taken" - SKIPPED = "skipped" - - -class SystemVariableKey(StrEnum): - """ - System Variables. 
- """ - - QUERY = "query" - FILES = "files" - CONVERSATION_ID = "conversation_id" - USER_ID = "user_id" - DIALOGUE_COUNT = "dialogue_count" - APP_ID = "app_id" - WORKFLOW_ID = "workflow_id" - WORKFLOW_EXECUTION_ID = "workflow_run_id" - TIMESTAMP = "timestamp" - # RAG Pipeline - DOCUMENT_ID = "document_id" - ORIGINAL_DOCUMENT_ID = "original_document_id" - BATCH = "batch" - DATASET_ID = "dataset_id" - DATASOURCE_TYPE = "datasource_type" - DATASOURCE_INFO = "datasource_info" - INVOKE_FROM = "invoke_from" - - -NodeType: TypeAlias = str - - -class BuiltinNodeTypes: - """Built-in node type string constants. - - `node_type` values are plain strings throughout the graph runtime. This namespace - only exposes the built-in values shipped by `dify_graph`; downstream packages can - use additional strings without extending this class. - """ - - START: ClassVar[NodeType] = "start" - END: ClassVar[NodeType] = "end" - ANSWER: ClassVar[NodeType] = "answer" - LLM: ClassVar[NodeType] = "llm" - KNOWLEDGE_RETRIEVAL: ClassVar[NodeType] = "knowledge-retrieval" - IF_ELSE: ClassVar[NodeType] = "if-else" - CODE: ClassVar[NodeType] = "code" - TEMPLATE_TRANSFORM: ClassVar[NodeType] = "template-transform" - QUESTION_CLASSIFIER: ClassVar[NodeType] = "question-classifier" - HTTP_REQUEST: ClassVar[NodeType] = "http-request" - TOOL: ClassVar[NodeType] = "tool" - DATASOURCE: ClassVar[NodeType] = "datasource" - VARIABLE_AGGREGATOR: ClassVar[NodeType] = "variable-aggregator" - LEGACY_VARIABLE_AGGREGATOR: ClassVar[NodeType] = "variable-assigner" - LOOP: ClassVar[NodeType] = "loop" - LOOP_START: ClassVar[NodeType] = "loop-start" - LOOP_END: ClassVar[NodeType] = "loop-end" - ITERATION: ClassVar[NodeType] = "iteration" - ITERATION_START: ClassVar[NodeType] = "iteration-start" - PARAMETER_EXTRACTOR: ClassVar[NodeType] = "parameter-extractor" - VARIABLE_ASSIGNER: ClassVar[NodeType] = "assigner" - DOCUMENT_EXTRACTOR: ClassVar[NodeType] = "document-extractor" - LIST_OPERATOR: ClassVar[NodeType] = 
"list-operator" - AGENT: ClassVar[NodeType] = "agent" - HUMAN_INPUT: ClassVar[NodeType] = "human-input" - - -BUILT_IN_NODE_TYPES: tuple[NodeType, ...] = ( - BuiltinNodeTypes.START, - BuiltinNodeTypes.END, - BuiltinNodeTypes.ANSWER, - BuiltinNodeTypes.LLM, - BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL, - BuiltinNodeTypes.IF_ELSE, - BuiltinNodeTypes.CODE, - BuiltinNodeTypes.TEMPLATE_TRANSFORM, - BuiltinNodeTypes.QUESTION_CLASSIFIER, - BuiltinNodeTypes.HTTP_REQUEST, - BuiltinNodeTypes.TOOL, - BuiltinNodeTypes.DATASOURCE, - BuiltinNodeTypes.VARIABLE_AGGREGATOR, - BuiltinNodeTypes.LEGACY_VARIABLE_AGGREGATOR, - BuiltinNodeTypes.LOOP, - BuiltinNodeTypes.LOOP_START, - BuiltinNodeTypes.LOOP_END, - BuiltinNodeTypes.ITERATION, - BuiltinNodeTypes.ITERATION_START, - BuiltinNodeTypes.PARAMETER_EXTRACTOR, - BuiltinNodeTypes.VARIABLE_ASSIGNER, - BuiltinNodeTypes.DOCUMENT_EXTRACTOR, - BuiltinNodeTypes.LIST_OPERATOR, - BuiltinNodeTypes.AGENT, - BuiltinNodeTypes.HUMAN_INPUT, -) - - -class NodeExecutionType(StrEnum): - """Node execution type classification.""" - - EXECUTABLE = "executable" # Regular nodes that execute and produce outputs - RESPONSE = "response" # Response nodes that stream outputs (Answer, End) - BRANCH = "branch" # Nodes that can choose different branches (if-else, question-classifier) - CONTAINER = "container" # Container nodes that manage subgraphs (iteration, loop, graph) - ROOT = "root" # Nodes that can serve as execution entry points - - -class ErrorStrategy(StrEnum): - FAIL_BRANCH = "fail-branch" - DEFAULT_VALUE = "default-value" - - -class FailBranchSourceHandle(StrEnum): - FAILED = "fail-branch" - SUCCESS = "success-branch" - - -class WorkflowType(StrEnum): - """ - Workflow Type Enum for domain layer - """ - - WORKFLOW = "workflow" - CHAT = "chat" - RAG_PIPELINE = "rag-pipeline" - - -class WorkflowExecutionStatus(StrEnum): - # State diagram for the workflw status: - # (@) means start, (*) means end - # - # 
┌------------------>------------------------->------------------->--------------┐ - # | | - # | ┌-----------------------<--------------------┐ | - # ^ | | | - # | | ^ | - # | V | | - # ┌-----------┐ ┌-----------------------┐ ┌-----------┐ V - # | Scheduled |------->| Running |---------------------->| paused | | - # └-----------┘ └-----------------------┘ └-----------┘ | - # | | | | | | | - # | | | | | | | - # ^ | | | V V | - # | | | | | ┌---------┐ | - # (@) | | | └------------------------>| Stopped |<----┘ - # | | | └---------┘ - # | | | | - # | | V V - # | | ┌-----------┐ | - # | | | Succeeded |------------->--------------┤ - # | | └-----------┘ | - # | V V - # | +--------┐ | - # | | Failed |---------------------->----------------┤ - # | └--------┘ | - # V V - # ┌---------------------┐ | - # | Partially Succeeded |---------------------->-----------------┘--------> (*) - # └---------------------┘ - # - # Mermaid diagram: - # - # --- - # title: State diagram for Workflow run state - # --- - # stateDiagram-v2 - # scheduled: Scheduled - # running: Running - # succeeded: Succeeded - # failed: Failed - # partial_succeeded: Partial Succeeded - # paused: Paused - # stopped: Stopped - # - # [*] --> scheduled: - # scheduled --> running: Start Execution - # running --> paused: Human input required - # paused --> running: human input added - # paused --> stopped: User stops execution - # running --> succeeded: Execution finishes without any error - # running --> failed: Execution finishes with errors - # running --> stopped: User stops execution - # running --> partial_succeeded: some execution occurred and handled during execution - # - # scheduled --> stopped: User stops execution - # - # succeeded --> [*] - # failed --> [*] - # partial_succeeded --> [*] - # stopped --> [*] - - # `SCHEDULED` means that the workflow is scheduled to run, but has not - # started running yet. (maybe due to possible worker saturation.) - # - # This enum value is currently unused. 
- SCHEDULED = "scheduled" - - # `RUNNING` means the workflow is exeuting. - RUNNING = "running" - - # `SUCCEEDED` means the execution of workflow succeed without any error. - SUCCEEDED = "succeeded" - - # `FAILED` means the execution of workflow failed without some errors. - FAILED = "failed" - - # `STOPPED` means the execution of workflow was stopped, either manually - # by the user, or automatically by the Dify application (E.G. the moderation - # mechanism.) - STOPPED = "stopped" - - # `PARTIAL_SUCCEEDED` indicates that some errors occurred during the workflow - # execution, but they were successfully handled (e.g., by using an error - # strategy such as "fail branch" or "default value"). - PARTIAL_SUCCEEDED = "partial-succeeded" - - # `PAUSED` indicates that the workflow execution is temporarily paused - # (e.g., awaiting human input) and is expected to resume later. - PAUSED = "paused" - - def is_ended(self) -> bool: - return self in _END_STATE - - @classmethod - def ended_values(cls) -> list[str]: - return [status.value for status in _END_STATE] - - -_END_STATE = frozenset( - [ - WorkflowExecutionStatus.SUCCEEDED, - WorkflowExecutionStatus.FAILED, - WorkflowExecutionStatus.PARTIAL_SUCCEEDED, - WorkflowExecutionStatus.STOPPED, - ] -) - - -class WorkflowNodeExecutionMetadataKey(StrEnum): - """ - Node Run Metadata Key. - - Values in this enum are persisted as execution metadata and must stay in sync - with every node that writes `NodeRunResult.metadata`. 
- """ - - TOTAL_TOKENS = "total_tokens" - TOTAL_PRICE = "total_price" - CURRENCY = "currency" - TOOL_INFO = "tool_info" - AGENT_LOG = "agent_log" - ITERATION_ID = "iteration_id" - ITERATION_INDEX = "iteration_index" - LOOP_ID = "loop_id" - LOOP_INDEX = "loop_index" - PARALLEL_ID = "parallel_id" - PARALLEL_START_NODE_ID = "parallel_start_node_id" - PARENT_PARALLEL_ID = "parent_parallel_id" - PARENT_PARALLEL_START_NODE_ID = "parent_parallel_start_node_id" - PARALLEL_MODE_RUN_ID = "parallel_mode_run_id" - ITERATION_DURATION_MAP = "iteration_duration_map" # single iteration duration if iteration node runs - LOOP_DURATION_MAP = "loop_duration_map" # single loop duration if loop node runs - ERROR_STRATEGY = "error_strategy" # node in continue on error mode return the field - LOOP_VARIABLE_MAP = "loop_variable_map" # single loop variable output - DATASOURCE_INFO = "datasource_info" - TRIGGER_INFO = "trigger_info" - COMPLETED_REASON = "completed_reason" # completed reason for loop node - - -class WorkflowNodeExecutionStatus(StrEnum): - PENDING = "pending" # Node is scheduled but not yet executing - RUNNING = "running" - SUCCEEDED = "succeeded" - FAILED = "failed" - EXCEPTION = "exception" - STOPPED = "stopped" - PAUSED = "paused" - - # Legacy statuses - kept for backward compatibility - RETRY = "retry" # Legacy: replaced by retry mechanism in error handling diff --git a/api/dify_graph/errors.py b/api/dify_graph/errors.py deleted file mode 100644 index 463d17713e..0000000000 --- a/api/dify_graph/errors.py +++ /dev/null @@ -1,16 +0,0 @@ -from dify_graph.nodes.base.node import Node - - -class WorkflowNodeRunFailedError(Exception): - def __init__(self, node: Node, err_msg: str): - self._node = node - self._error = err_msg - super().__init__(f"Node {node.title} run failed: {err_msg}") - - @property - def node(self) -> Node: - return self._node - - @property - def error(self) -> str: - return self._error diff --git a/api/dify_graph/file/__init__.py 
b/api/dify_graph/file/__init__.py deleted file mode 100644 index 44749ebec3..0000000000 --- a/api/dify_graph/file/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -from .constants import FILE_MODEL_IDENTITY -from .enums import ArrayFileAttribute, FileAttribute, FileBelongsTo, FileTransferMethod, FileType -from .models import ( - File, - FileUploadConfig, - ImageConfig, -) - -__all__ = [ - "FILE_MODEL_IDENTITY", - "ArrayFileAttribute", - "File", - "FileAttribute", - "FileBelongsTo", - "FileTransferMethod", - "FileType", - "FileUploadConfig", - "ImageConfig", -] diff --git a/api/dify_graph/file/constants.py b/api/dify_graph/file/constants.py deleted file mode 100644 index 0665ed7e0d..0000000000 --- a/api/dify_graph/file/constants.py +++ /dev/null @@ -1,11 +0,0 @@ -from typing import Any - -# TODO(QuantumGhost): Refactor variable type identification. Instead of directly -# comparing `dify_model_identity` with constants throughout the codebase, extract -# this logic into a dedicated function. This would encapsulate the implementation -# details of how different variable types are identified. 
-FILE_MODEL_IDENTITY = "__dify__file__" - - -def maybe_file_object(o: Any) -> bool: - return isinstance(o, dict) and o.get("dify_model_identity") == FILE_MODEL_IDENTITY diff --git a/api/dify_graph/file/enums.py b/api/dify_graph/file/enums.py deleted file mode 100644 index 170eb4fc23..0000000000 --- a/api/dify_graph/file/enums.py +++ /dev/null @@ -1,57 +0,0 @@ -from enum import StrEnum - - -class FileType(StrEnum): - IMAGE = "image" - DOCUMENT = "document" - AUDIO = "audio" - VIDEO = "video" - CUSTOM = "custom" - - @staticmethod - def value_of(value): - for member in FileType: - if member.value == value: - return member - raise ValueError(f"No matching enum found for value '{value}'") - - -class FileTransferMethod(StrEnum): - REMOTE_URL = "remote_url" - LOCAL_FILE = "local_file" - TOOL_FILE = "tool_file" - DATASOURCE_FILE = "datasource_file" - - @staticmethod - def value_of(value): - for member in FileTransferMethod: - if member.value == value: - return member - raise ValueError(f"No matching enum found for value '{value}'") - - -class FileBelongsTo(StrEnum): - USER = "user" - ASSISTANT = "assistant" - - @staticmethod - def value_of(value): - for member in FileBelongsTo: - if member.value == value: - return member - raise ValueError(f"No matching enum found for value '{value}'") - - -class FileAttribute(StrEnum): - TYPE = "type" - SIZE = "size" - NAME = "name" - MIME_TYPE = "mime_type" - TRANSFER_METHOD = "transfer_method" - URL = "url" - EXTENSION = "extension" - RELATED_ID = "related_id" - - -class ArrayFileAttribute(StrEnum): - LENGTH = "length" diff --git a/api/dify_graph/file/file_manager.py b/api/dify_graph/file/file_manager.py deleted file mode 100644 index 8d998054db..0000000000 --- a/api/dify_graph/file/file_manager.py +++ /dev/null @@ -1,143 +0,0 @@ -from __future__ import annotations - -import base64 -from collections.abc import Mapping - -from dify_graph.model_runtime.entities import ( - AudioPromptMessageContent, - DocumentPromptMessageContent, - 
ImagePromptMessageContent, - TextPromptMessageContent, - VideoPromptMessageContent, -) -from dify_graph.model_runtime.entities.message_entities import PromptMessageContentUnionTypes - -from . import helpers -from .enums import FileAttribute -from .models import File, FileTransferMethod, FileType -from .runtime import get_workflow_file_runtime - - -def get_attr(*, file: File, attr: FileAttribute): - match attr: - case FileAttribute.TYPE: - return file.type.value - case FileAttribute.SIZE: - return file.size - case FileAttribute.NAME: - return file.filename - case FileAttribute.MIME_TYPE: - return file.mime_type - case FileAttribute.TRANSFER_METHOD: - return file.transfer_method.value - case FileAttribute.URL: - return _to_url(file) - case FileAttribute.EXTENSION: - return file.extension - case FileAttribute.RELATED_ID: - return file.related_id - - -def to_prompt_message_content( - f: File, - /, - *, - image_detail_config: ImagePromptMessageContent.DETAIL | None = None, -) -> PromptMessageContentUnionTypes: - """Convert a file to prompt message content.""" - if f.extension is None: - raise ValueError("Missing file extension") - if f.mime_type is None: - raise ValueError("Missing file mime_type") - - prompt_class_map: Mapping[FileType, type[PromptMessageContentUnionTypes]] = { - FileType.IMAGE: ImagePromptMessageContent, - FileType.AUDIO: AudioPromptMessageContent, - FileType.VIDEO: VideoPromptMessageContent, - FileType.DOCUMENT: DocumentPromptMessageContent, - } - - if f.type not in prompt_class_map: - return TextPromptMessageContent(data=f"[Unsupported file type: {f.filename} ({f.type.value})]") - - send_format = get_workflow_file_runtime().multimodal_send_format - params = { - "base64_data": _get_encoded_string(f) if send_format == "base64" else "", - "url": _to_url(f) if send_format == "url" else "", - "format": f.extension.removeprefix("."), - "mime_type": f.mime_type, - "filename": f.filename or "", - } - if f.type == FileType.IMAGE: - params["detail"] = 
image_detail_config or ImagePromptMessageContent.DETAIL.LOW - - return prompt_class_map[f.type].model_validate(params) - - -def download(f: File, /) -> bytes: - if f.transfer_method in ( - FileTransferMethod.TOOL_FILE, - FileTransferMethod.LOCAL_FILE, - FileTransferMethod.DATASOURCE_FILE, - ): - return _download_file_content(f.storage_key) - elif f.transfer_method == FileTransferMethod.REMOTE_URL: - if f.remote_url is None: - raise ValueError("Missing file remote_url") - response = get_workflow_file_runtime().http_get(f.remote_url, follow_redirects=True) - response.raise_for_status() - return response.content - raise ValueError(f"unsupported transfer method: {f.transfer_method}") - - -def _download_file_content(path: str, /) -> bytes: - """Download and return a file from storage as bytes.""" - data = get_workflow_file_runtime().storage_load(path, stream=False) - if not isinstance(data, bytes): - raise ValueError(f"file {path} is not a bytes object") - return data - - -def _get_encoded_string(f: File, /) -> str: - match f.transfer_method: - case FileTransferMethod.REMOTE_URL: - if f.remote_url is None: - raise ValueError("Missing file remote_url") - response = get_workflow_file_runtime().http_get(f.remote_url, follow_redirects=True) - response.raise_for_status() - data = response.content - case FileTransferMethod.LOCAL_FILE: - data = _download_file_content(f.storage_key) - case FileTransferMethod.TOOL_FILE: - data = _download_file_content(f.storage_key) - case FileTransferMethod.DATASOURCE_FILE: - data = _download_file_content(f.storage_key) - - return base64.b64encode(data).decode("utf-8") - - -def _to_url(f: File, /): - if f.transfer_method == FileTransferMethod.REMOTE_URL: - if f.remote_url is None: - raise ValueError("Missing file remote_url") - return f.remote_url - elif f.transfer_method == FileTransferMethod.LOCAL_FILE: - if f.related_id is None: - raise ValueError("Missing file related_id") - return f.remote_url or 
helpers.get_signed_file_url(upload_file_id=f.related_id) - elif f.transfer_method == FileTransferMethod.TOOL_FILE: - if f.related_id is None or f.extension is None: - raise ValueError("Missing file related_id or extension") - return helpers.get_signed_tool_file_url(tool_file_id=f.related_id, extension=f.extension) - else: - raise ValueError(f"Unsupported transfer method: {f.transfer_method}") - - -class FileManager: - """Adapter exposing file manager helpers behind FileManagerProtocol.""" - - def download(self, f: File, /) -> bytes: - return download(f) - - -file_manager = FileManager() diff --git a/api/dify_graph/file/helpers.py b/api/dify_graph/file/helpers.py deleted file mode 100644 index 310cb1310b..0000000000 --- a/api/dify_graph/file/helpers.py +++ /dev/null @@ -1,92 +0,0 @@ -from __future__ import annotations - -import base64 -import hashlib -import hmac -import os -import time -import urllib.parse - -from .runtime import get_workflow_file_runtime - - -def get_signed_file_url(upload_file_id: str, as_attachment: bool = False, for_external: bool = True) -> str: - runtime = get_workflow_file_runtime() - base_url = runtime.files_url if for_external else (runtime.internal_files_url or runtime.files_url) - url = f"{base_url}/files/{upload_file_id}/file-preview" - - timestamp = str(int(time.time())) - nonce = os.urandom(16).hex() - key = runtime.secret_key.encode() - msg = f"file-preview|{upload_file_id}|{timestamp}|{nonce}" - sign = hmac.new(key, msg.encode(), hashlib.sha256).digest() - encoded_sign = base64.urlsafe_b64encode(sign).decode() - query: dict[str, str] = {"timestamp": timestamp, "nonce": nonce, "sign": encoded_sign} - if as_attachment: - query["as_attachment"] = "true" - query_string = urllib.parse.urlencode(query) - - return f"{url}?{query_string}" - - -def get_signed_file_url_for_plugin(filename: str, mimetype: str, tenant_id: str, user_id: str) -> str: - runtime = get_workflow_file_runtime() - # Plugin access should use internal URL for Docker 
network communication. - base_url = runtime.internal_files_url or runtime.files_url - url = f"{base_url}/files/upload/for-plugin" - timestamp = str(int(time.time())) - nonce = os.urandom(16).hex() - key = runtime.secret_key.encode() - msg = f"upload|{filename}|{mimetype}|{tenant_id}|{user_id}|{timestamp}|{nonce}" - sign = hmac.new(key, msg.encode(), hashlib.sha256).digest() - encoded_sign = base64.urlsafe_b64encode(sign).decode() - return f"{url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}&user_id={user_id}&tenant_id={tenant_id}" - - -def get_signed_tool_file_url(tool_file_id: str, extension: str, for_external: bool = True) -> str: - runtime = get_workflow_file_runtime() - return runtime.sign_tool_file(tool_file_id=tool_file_id, extension=extension, for_external=for_external) - - -def verify_plugin_file_signature( - *, filename: str, mimetype: str, tenant_id: str, user_id: str, timestamp: str, nonce: str, sign: str -) -> bool: - runtime = get_workflow_file_runtime() - data_to_sign = f"upload|{filename}|{mimetype}|{tenant_id}|{user_id}|{timestamp}|{nonce}" - secret_key = runtime.secret_key.encode() - recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() - recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode() - - if sign != recalculated_encoded_sign: - return False - - current_time = int(time.time()) - return current_time - int(timestamp) <= runtime.files_access_timeout - - -def verify_image_signature(*, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool: - runtime = get_workflow_file_runtime() - data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}" - secret_key = runtime.secret_key.encode() - recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() - recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode() - - if sign != recalculated_encoded_sign: - return False - - current_time = int(time.time()) - return 
current_time - int(timestamp) <= runtime.files_access_timeout - - -def verify_file_signature(*, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool: - runtime = get_workflow_file_runtime() - data_to_sign = f"file-preview|{upload_file_id}|{timestamp}|{nonce}" - secret_key = runtime.secret_key.encode() - recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest() - recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode() - - if sign != recalculated_encoded_sign: - return False - - current_time = int(time.time()) - return current_time - int(timestamp) <= runtime.files_access_timeout diff --git a/api/dify_graph/file/models.py b/api/dify_graph/file/models.py deleted file mode 100644 index dcba00978e..0000000000 --- a/api/dify_graph/file/models.py +++ /dev/null @@ -1,197 +0,0 @@ -from __future__ import annotations - -from collections.abc import Mapping, Sequence -from typing import Any -from uuid import UUID, uuid4 - -from pydantic import BaseModel, Field, model_validator - -from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent - -from . import helpers -from .constants import FILE_MODEL_IDENTITY -from .enums import FileTransferMethod, FileType - - -def sign_tool_file(*, tool_file_id: str, extension: str, for_external: bool = True) -> str: - """Compatibility shim for tests and legacy callers patching ``models.sign_tool_file``.""" - return helpers.get_signed_tool_file_url( - tool_file_id=tool_file_id, - extension=extension, - for_external=for_external, - ) - - -class ImageConfig(BaseModel): - """ - NOTE: This part of validation is deprecated, but still used in app features "Image Upload". - """ - - number_limits: int = 0 - transfer_methods: Sequence[FileTransferMethod] = Field(default_factory=list) - detail: ImagePromptMessageContent.DETAIL | None = None - - -class FileUploadConfig(BaseModel): - """ - File Upload Entity. 
- """ - - image_config: ImageConfig | None = None - allowed_file_types: Sequence[FileType] = Field(default_factory=list) - allowed_file_extensions: Sequence[str] = Field(default_factory=list) - allowed_file_upload_methods: Sequence[FileTransferMethod] = Field(default_factory=list) - number_limits: int = 0 - - -class ToolFile(BaseModel): - id: UUID = Field(default_factory=uuid4, description="Unique identifier for the file") - user_id: UUID = Field(..., description="ID of the user who owns this file") - tenant_id: UUID = Field(..., description="ID of the tenant/organization") - conversation_id: UUID | None = Field(None, description="ID of the associated conversation") - file_key: str = Field(..., max_length=255, description="Storage key for the file") - mimetype: str = Field(..., max_length=255, description="MIME type of the file") - original_url: str | None = Field( - None, max_length=2048, description="Original URL if file was fetched from external source" - ) - name: str = Field(default="", max_length=255, description="Display name of the file") - size: int = Field(default=-1, ge=-1, description="File size in bytes (-1 if unknown)") - - class Config: - from_attributes = True # Enable ORM mode for SQLAlchemy compatibility - populate_by_name = True - - -class File(BaseModel): - # NOTE: dify_model_identity is a special identifier used to distinguish between - # new and old data formats during serialization and deserialization. - dify_model_identity: str = FILE_MODEL_IDENTITY - - id: str | None = None # message file id - tenant_id: str - type: FileType - transfer_method: FileTransferMethod - # If `transfer_method` is `FileTransferMethod.remote_url`, the - # `remote_url` attribute must not be `None`. - remote_url: str | None = None # remote url - # If `transfer_method` is `FileTransferMethod.local_file` or - # `FileTransferMethod.tool_file`, the `related_id` attribute must not be `None`. - # - # It should be set to `ToolFile.id` when `transfer_method` is `tool_file`. 
- related_id: str | None = None - filename: str | None = None - extension: str | None = Field(default=None, description="File extension, should contain dot") - mime_type: str | None = None - size: int = -1 - - # Those properties are private, should not be exposed to the outside. - _storage_key: str - - def __init__( - self, - *, - id: str | None = None, - tenant_id: str, - type: FileType, - transfer_method: FileTransferMethod, - remote_url: str | None = None, - related_id: str | None = None, - filename: str | None = None, - extension: str | None = None, - mime_type: str | None = None, - size: int = -1, - storage_key: str | None = None, - dify_model_identity: str | None = FILE_MODEL_IDENTITY, - url: str | None = None, - # Legacy compatibility fields - explicitly handle known extra fields - tool_file_id: str | None = None, - upload_file_id: str | None = None, - datasource_file_id: str | None = None, - ): - super().__init__( - id=id, - tenant_id=tenant_id, - type=type, - transfer_method=transfer_method, - remote_url=remote_url, - related_id=related_id, - filename=filename, - extension=extension, - mime_type=mime_type, - size=size, - dify_model_identity=dify_model_identity, - url=url, - ) - self._storage_key = str(storage_key) - - def to_dict(self) -> Mapping[str, str | int | None]: - data = self.model_dump(mode="json") - return { - **data, - "url": self.generate_url(), - } - - @property - def markdown(self) -> str: - url = self.generate_url() - if self.type == FileType.IMAGE: - text = f"![{self.filename or ''}]({url})" - else: - text = f"[{self.filename or url}]({url})" - - return text - - def generate_url(self, for_external: bool = True) -> str | None: - if self.transfer_method == FileTransferMethod.REMOTE_URL: - return self.remote_url - elif self.transfer_method == FileTransferMethod.LOCAL_FILE: - if self.related_id is None: - raise ValueError("Missing file related_id") - return helpers.get_signed_file_url(upload_file_id=self.related_id, for_external=for_external) - 
elif self.transfer_method in [FileTransferMethod.TOOL_FILE, FileTransferMethod.DATASOURCE_FILE]: - assert self.related_id is not None - assert self.extension is not None - return sign_tool_file( - tool_file_id=self.related_id, - extension=self.extension, - for_external=for_external, - ) - return None - - def to_plugin_parameter(self) -> dict[str, Any]: - return { - "dify_model_identity": FILE_MODEL_IDENTITY, - "mime_type": self.mime_type, - "filename": self.filename, - "extension": self.extension, - "size": self.size, - "type": self.type, - "url": self.generate_url(for_external=False), - } - - @model_validator(mode="after") - def validate_after(self) -> File: - match self.transfer_method: - case FileTransferMethod.REMOTE_URL: - if not self.remote_url: - raise ValueError("Missing file url") - if not isinstance(self.remote_url, str) or not self.remote_url.startswith("http"): - raise ValueError("Invalid file url") - case FileTransferMethod.LOCAL_FILE: - if not self.related_id: - raise ValueError("Missing file related_id") - case FileTransferMethod.TOOL_FILE: - if not self.related_id: - raise ValueError("Missing file related_id") - case FileTransferMethod.DATASOURCE_FILE: - if not self.related_id: - raise ValueError("Missing file related_id") - return self - - @property - def storage_key(self) -> str: - return self._storage_key - - @storage_key.setter - def storage_key(self, value: str) -> None: - self._storage_key = value diff --git a/api/dify_graph/file/protocols.py b/api/dify_graph/file/protocols.py deleted file mode 100644 index 24cbb42735..0000000000 --- a/api/dify_graph/file/protocols.py +++ /dev/null @@ -1,43 +0,0 @@ -from __future__ import annotations - -from collections.abc import Generator -from typing import Protocol - - -class HttpResponseProtocol(Protocol): - """Subset of response behavior needed by workflow file helpers.""" - - @property - def content(self) -> bytes: ... - - def raise_for_status(self) -> object: ... 
- - -class WorkflowFileRuntimeProtocol(Protocol): - """Runtime dependencies required by ``dify_graph.file``. - - Implementations are expected to be provided by integration layers (for example, - ``core.app.workflow.file_runtime``) so the workflow package avoids importing - application infrastructure modules directly. - """ - - @property - def files_url(self) -> str: ... - - @property - def internal_files_url(self) -> str | None: ... - - @property - def secret_key(self) -> str: ... - - @property - def files_access_timeout(self) -> int: ... - - @property - def multimodal_send_format(self) -> str: ... - - def http_get(self, url: str, *, follow_redirects: bool = True) -> HttpResponseProtocol: ... - - def storage_load(self, path: str, *, stream: bool = False) -> bytes | Generator: ... - - def sign_tool_file(self, *, tool_file_id: str, extension: str, for_external: bool = True) -> str: ... diff --git a/api/dify_graph/file/runtime.py b/api/dify_graph/file/runtime.py deleted file mode 100644 index 94253e0255..0000000000 --- a/api/dify_graph/file/runtime.py +++ /dev/null @@ -1,58 +0,0 @@ -from __future__ import annotations - -from collections.abc import Generator -from typing import NoReturn - -from .protocols import HttpResponseProtocol, WorkflowFileRuntimeProtocol - - -class WorkflowFileRuntimeNotConfiguredError(RuntimeError): - """Raised when workflow file runtime dependencies were not configured.""" - - -class _UnconfiguredWorkflowFileRuntime(WorkflowFileRuntimeProtocol): - def _raise(self) -> NoReturn: - raise WorkflowFileRuntimeNotConfiguredError( - "workflow file runtime is not configured, call set_workflow_file_runtime(...) 
first" - ) - - @property - def files_url(self) -> str: - self._raise() - - @property - def internal_files_url(self) -> str | None: - self._raise() - - @property - def secret_key(self) -> str: - self._raise() - - @property - def files_access_timeout(self) -> int: - self._raise() - - @property - def multimodal_send_format(self) -> str: - self._raise() - - def http_get(self, url: str, *, follow_redirects: bool = True) -> HttpResponseProtocol: - self._raise() - - def storage_load(self, path: str, *, stream: bool = False) -> bytes | Generator: - self._raise() - - def sign_tool_file(self, *, tool_file_id: str, extension: str, for_external: bool = True) -> str: - self._raise() - - -_runtime: WorkflowFileRuntimeProtocol = _UnconfiguredWorkflowFileRuntime() - - -def set_workflow_file_runtime(runtime: WorkflowFileRuntimeProtocol) -> None: - global _runtime - _runtime = runtime - - -def get_workflow_file_runtime() -> WorkflowFileRuntimeProtocol: - return _runtime diff --git a/api/dify_graph/file/tool_file_parser.py b/api/dify_graph/file/tool_file_parser.py deleted file mode 100644 index 2d7a3d43df..0000000000 --- a/api/dify_graph/file/tool_file_parser.py +++ /dev/null @@ -1,9 +0,0 @@ -from collections.abc import Callable -from typing import Any - -_tool_file_manager_factory: Callable[[], Any] | None = None - - -def set_tool_file_manager_factory(factory: Callable[[], Any]): - global _tool_file_manager_factory - _tool_file_manager_factory = factory diff --git a/api/dify_graph/graph/__init__.py b/api/dify_graph/graph/__init__.py deleted file mode 100644 index 4830ea83d3..0000000000 --- a/api/dify_graph/graph/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -from .edge import Edge -from .graph import Graph, GraphBuilder, NodeFactory -from .graph_template import GraphTemplate - -__all__ = [ - "Edge", - "Graph", - "GraphBuilder", - "GraphTemplate", - "NodeFactory", -] diff --git a/api/dify_graph/graph/edge.py b/api/dify_graph/graph/edge.py deleted file mode 100644 index 
f4f67ea6be..0000000000 --- a/api/dify_graph/graph/edge.py +++ /dev/null @@ -1,15 +0,0 @@ -import uuid -from dataclasses import dataclass, field - -from dify_graph.enums import NodeState - - -@dataclass -class Edge: - """Edge connecting two nodes in a workflow graph.""" - - id: str = field(default_factory=lambda: str(uuid.uuid4())) - tail: str = "" # tail node id (source) - head: str = "" # head node id (target) - source_handle: str = "source" # source handle for conditional branching - state: NodeState = field(default=NodeState.UNKNOWN) # edge execution state diff --git a/api/dify_graph/graph/graph.py b/api/dify_graph/graph/graph.py deleted file mode 100644 index 85117583e0..0000000000 --- a/api/dify_graph/graph/graph.py +++ /dev/null @@ -1,439 +0,0 @@ -from __future__ import annotations - -import logging -from collections import defaultdict -from collections.abc import Mapping, Sequence -from typing import Protocol, cast, final - -from pydantic import TypeAdapter - -from dify_graph.entities.graph_config import NodeConfigDict -from dify_graph.enums import ErrorStrategy, NodeExecutionType, NodeState -from dify_graph.nodes.base.node import Node -from libs.typing import is_str - -from .edge import Edge -from .validation import get_graph_validator - -logger = logging.getLogger(__name__) - -_ListNodeConfigDict = TypeAdapter(list[NodeConfigDict]) - - -class NodeFactory(Protocol): - """ - Protocol for creating Node instances from node data dictionaries. - - This protocol decouples the Graph class from specific node mapping implementations, - allowing for different node creation strategies while maintaining type safety. - """ - - def create_node(self, node_config: NodeConfigDict) -> Node: - """ - Create a Node instance from node configuration data. 
- - :param node_config: node configuration dictionary containing type and other data - :return: initialized Node instance - :raises ValueError: if node type is unknown or no implementation exists for the resolved version - :raises ValidationError: if node_config does not satisfy NodeConfigDict/BaseNodeData validation - """ - ... - - -@final -class Graph: - """Graph representation with nodes and edges for workflow execution.""" - - def __init__( - self, - *, - nodes: dict[str, Node] | None = None, - edges: dict[str, Edge] | None = None, - in_edges: dict[str, list[str]] | None = None, - out_edges: dict[str, list[str]] | None = None, - root_node: Node, - ): - """ - Initialize Graph instance. - - :param nodes: graph nodes mapping (node id: node object) - :param edges: graph edges mapping (edge id: edge object) - :param in_edges: incoming edges mapping (node id: list of edge ids) - :param out_edges: outgoing edges mapping (node id: list of edge ids) - :param root_node: root node object - """ - self.nodes = nodes or {} - self.edges = edges or {} - self.in_edges = in_edges or {} - self.out_edges = out_edges or {} - self.root_node = root_node - - @classmethod - def _parse_node_configs(cls, node_configs: list[NodeConfigDict]) -> dict[str, NodeConfigDict]: - """ - Parse node configurations and build a mapping of node IDs to configs. - - :param node_configs: list of node configuration dictionaries - :return: mapping of node ID to node config - """ - node_configs_map: dict[str, NodeConfigDict] = {} - - for node_config in node_configs: - node_configs_map[node_config["id"]] = node_config - - return node_configs_map - - @classmethod - def _build_edges( - cls, edge_configs: list[dict[str, object]] - ) -> tuple[dict[str, Edge], dict[str, list[str]], dict[str, list[str]]]: - """ - Build edge objects and mappings from edge configurations. 
- - :param edge_configs: list of edge configurations - :return: tuple of (edges dict, in_edges dict, out_edges dict) - """ - edges: dict[str, Edge] = {} - in_edges: dict[str, list[str]] = defaultdict(list) - out_edges: dict[str, list[str]] = defaultdict(list) - - edge_counter = 0 - for edge_config in edge_configs: - source = edge_config.get("source") - target = edge_config.get("target") - - if not is_str(source) or not is_str(target): - continue - - # Create edge - edge_id = f"edge_{edge_counter}" - edge_counter += 1 - - source_handle = edge_config.get("sourceHandle", "source") - if not is_str(source_handle): - continue - - edge = Edge( - id=edge_id, - tail=source, - head=target, - source_handle=source_handle, - ) - - edges[edge_id] = edge - out_edges[source].append(edge_id) - in_edges[target].append(edge_id) - - return edges, dict(in_edges), dict(out_edges) - - @classmethod - def _create_node_instances( - cls, - node_configs_map: dict[str, NodeConfigDict], - node_factory: NodeFactory, - ) -> dict[str, Node]: - """ - Create node instances from configurations using the node factory. - - :param node_configs_map: mapping of node ID to node config - :param node_factory: factory for creating node instances - :return: mapping of node ID to node instance - """ - nodes: dict[str, Node] = {} - - for node_id, node_config in node_configs_map.items(): - try: - node_instance = node_factory.create_node(node_config) - except Exception: - logger.exception("Failed to create node instance for node_id %s", node_id) - raise - nodes[node_id] = node_instance - - return nodes - - @classmethod - def new(cls) -> GraphBuilder: - """Create a fluent builder for assembling a graph programmatically.""" - - return GraphBuilder(graph_cls=cls) - - @staticmethod - def _filter_canvas_only_nodes(node_configs: Sequence[Mapping[str, object]]) -> list[dict[str, object]]: - """ - Remove editor-only nodes before `NodeConfigDict` validation. 
- - Persisted note widgets use a top-level `type == "custom-note"` but leave - `data.type` empty because they are never executable graph nodes. Filter - them while configs are still raw dicts so Pydantic does not validate - their placeholder payloads against `BaseNodeData.type: NodeType`. - """ - filtered_node_configs: list[dict[str, object]] = [] - for node_config in node_configs: - if node_config.get("type", "") == "custom-note": - continue - filtered_node_configs.append(dict(node_config)) - return filtered_node_configs - - @classmethod - def _promote_fail_branch_nodes(cls, nodes: dict[str, Node]) -> None: - """ - Promote nodes configured with FAIL_BRANCH error strategy to branch execution type. - - :param nodes: mapping of node ID to node instance - """ - for node in nodes.values(): - if node.error_strategy == ErrorStrategy.FAIL_BRANCH: - node.execution_type = NodeExecutionType.BRANCH - - @classmethod - def _mark_inactive_root_branches( - cls, - nodes: dict[str, Node], - edges: dict[str, Edge], - in_edges: dict[str, list[str]], - out_edges: dict[str, list[str]], - active_root_id: str, - ) -> None: - """ - Mark nodes and edges from inactive root branches as skipped. - - Algorithm: - 1. Mark inactive root nodes as skipped - 2. For skipped nodes, mark all their outgoing edges as skipped - 3. 
For each edge marked as skipped, check its target node: - - If ALL incoming edges are skipped, mark the node as skipped - - Otherwise, leave the node state unchanged - - :param nodes: mapping of node ID to node instance - :param edges: mapping of edge ID to edge instance - :param in_edges: mapping of node ID to incoming edge IDs - :param out_edges: mapping of node ID to outgoing edge IDs - :param active_root_id: ID of the active root node - """ - # Find all top-level root nodes (nodes with ROOT execution type and no incoming edges) - top_level_roots: list[str] = [ - node.id for node in nodes.values() if node.execution_type == NodeExecutionType.ROOT - ] - - # If there's only one root or the active root is not a top-level root, no marking needed - if len(top_level_roots) <= 1 or active_root_id not in top_level_roots: - return - - # Mark inactive root nodes as skipped - inactive_roots: list[str] = [root_id for root_id in top_level_roots if root_id != active_root_id] - for root_id in inactive_roots: - if root_id in nodes: - nodes[root_id].state = NodeState.SKIPPED - - # Recursively mark downstream nodes and edges - def mark_downstream(node_id: str) -> None: - """Recursively mark downstream nodes and edges as skipped.""" - if nodes[node_id].state != NodeState.SKIPPED: - return - # If this node is skipped, mark all its outgoing edges as skipped - out_edge_ids = out_edges.get(node_id, []) - for edge_id in out_edge_ids: - edge = edges[edge_id] - edge.state = NodeState.SKIPPED - - # Check the target node of this edge - target_node = nodes[edge.head] - in_edge_ids = in_edges.get(target_node.id, []) - in_edge_states = [edges[eid].state for eid in in_edge_ids] - - # If all incoming edges are skipped, mark the node as skipped - if all(state == NodeState.SKIPPED for state in in_edge_states): - target_node.state = NodeState.SKIPPED - # Recursively process downstream nodes - mark_downstream(target_node.id) - - # Process each inactive root and its downstream nodes - for root_id in 
inactive_roots: - mark_downstream(root_id) - - @classmethod - def init( - cls, - *, - graph_config: Mapping[str, object], - node_factory: NodeFactory, - root_node_id: str, - skip_validation: bool = False, - ) -> Graph: - """ - Initialize a graph with an explicit execution entry point. - - :param graph_config: graph config containing nodes and edges - :param node_factory: factory for creating node instances from config data - :param root_node_id: active root node id - :return: graph instance - """ - # Parse configs - edge_configs = graph_config.get("edges", []) - node_configs = graph_config.get("nodes", []) - - edge_configs = cast(list[dict[str, object]], edge_configs) - node_configs = cast(list[dict[str, object]], node_configs) - node_configs = cls._filter_canvas_only_nodes(node_configs) - node_configs = _ListNodeConfigDict.validate_python(node_configs) - - if not node_configs: - raise ValueError("Graph must have at least one node") - - # Parse node configurations - node_configs_map = cls._parse_node_configs(node_configs) - - if root_node_id not in node_configs_map: - raise ValueError(f"Root node id {root_node_id} not found in the graph") - - # Build edges - edges, in_edges, out_edges = cls._build_edges(edge_configs) - - # Create node instances - nodes = cls._create_node_instances(node_configs_map, node_factory) - - # Promote fail-branch nodes to branch execution type at graph level - cls._promote_fail_branch_nodes(nodes) - - # Get root node instance - root_node = nodes[root_node_id] - - # Mark inactive root branches as skipped - cls._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, root_node_id) - - # Create and return the graph - graph = cls( - nodes=nodes, - edges=edges, - in_edges=in_edges, - out_edges=out_edges, - root_node=root_node, - ) - - if not skip_validation: - # Validate the graph structure using built-in validators - get_graph_validator().validate(graph) - - return graph - - @property - def node_ids(self) -> list[str]: - """ - Get list 
of node IDs (compatibility property for existing code) - - :return: list of node IDs - """ - return list(self.nodes.keys()) - - def get_outgoing_edges(self, node_id: str) -> list[Edge]: - """ - Get all outgoing edges from a node (V2 method) - - :param node_id: node id - :return: list of outgoing edges - """ - edge_ids = self.out_edges.get(node_id, []) - return [self.edges[eid] for eid in edge_ids if eid in self.edges] - - def get_incoming_edges(self, node_id: str) -> list[Edge]: - """ - Get all incoming edges to a node (V2 method) - - :param node_id: node id - :return: list of incoming edges - """ - edge_ids = self.in_edges.get(node_id, []) - return [self.edges[eid] for eid in edge_ids if eid in self.edges] - - -@final -class GraphBuilder: - """Fluent helper for constructing simple graphs, primarily for tests.""" - - def __init__(self, *, graph_cls: type[Graph]): - self._graph_cls = graph_cls - self._nodes: list[Node] = [] - self._nodes_by_id: dict[str, Node] = {} - self._edges: list[Edge] = [] - self._edge_counter = 0 - - def add_root(self, node: Node) -> GraphBuilder: - """Register the root node. 
Must be called exactly once.""" - - if self._nodes: - raise ValueError("Root node has already been added") - self._register_node(node) - self._nodes.append(node) - return self - - def add_node( - self, - node: Node, - *, - from_node_id: str | None = None, - source_handle: str = "source", - ) -> GraphBuilder: - """Append a node and connect it from the specified predecessor.""" - - if not self._nodes: - raise ValueError("Root node must be added before adding other nodes") - - predecessor_id = from_node_id or self._nodes[-1].id - if predecessor_id not in self._nodes_by_id: - raise ValueError(f"Predecessor node '{predecessor_id}' not found") - - predecessor = self._nodes_by_id[predecessor_id] - self._register_node(node) - self._nodes.append(node) - - edge_id = f"edge_{self._edge_counter}" - self._edge_counter += 1 - edge = Edge(id=edge_id, tail=predecessor.id, head=node.id, source_handle=source_handle) - self._edges.append(edge) - - return self - - def connect(self, *, tail: str, head: str, source_handle: str = "source") -> GraphBuilder: - """Connect two existing nodes without adding a new node.""" - - if tail not in self._nodes_by_id: - raise ValueError(f"Tail node '{tail}' not found") - if head not in self._nodes_by_id: - raise ValueError(f"Head node '{head}' not found") - - edge_id = f"edge_{self._edge_counter}" - self._edge_counter += 1 - edge = Edge(id=edge_id, tail=tail, head=head, source_handle=source_handle) - self._edges.append(edge) - - return self - - def build(self) -> Graph: - """Materialize the graph instance from the accumulated nodes and edges.""" - - if not self._nodes: - raise ValueError("Cannot build an empty graph") - - nodes = {node.id: node for node in self._nodes} - edges = {edge.id: edge for edge in self._edges} - in_edges: dict[str, list[str]] = defaultdict(list) - out_edges: dict[str, list[str]] = defaultdict(list) - - for edge in self._edges: - out_edges[edge.tail].append(edge.id) - in_edges[edge.head].append(edge.id) - - return 
self._graph_cls( - nodes=nodes, - edges=edges, - in_edges=dict(in_edges), - out_edges=dict(out_edges), - root_node=self._nodes[0], - ) - - def _register_node(self, node: Node) -> None: - if not node.id: - raise ValueError("Node must have a non-empty id") - if node.id in self._nodes_by_id: - raise ValueError(f"Duplicate node id detected: {node.id}") - self._nodes_by_id[node.id] = node diff --git a/api/dify_graph/graph/graph_template.py b/api/dify_graph/graph/graph_template.py deleted file mode 100644 index 34e2dc19e6..0000000000 --- a/api/dify_graph/graph/graph_template.py +++ /dev/null @@ -1,20 +0,0 @@ -from typing import Any - -from pydantic import BaseModel, Field - - -class GraphTemplate(BaseModel): - """ - Graph Template for container nodes and subgraph expansion - - According to GraphEngine V2 spec, GraphTemplate contains: - - nodes: mapping of node definitions - - edges: mapping of edge definitions - - root_ids: list of root node IDs - - output_selectors: list of output selectors for the template - """ - - nodes: dict[str, dict[str, Any]] = Field(default_factory=dict, description="node definitions mapping") - edges: dict[str, dict[str, Any]] = Field(default_factory=dict, description="edge definitions mapping") - root_ids: list[str] = Field(default_factory=list, description="root node IDs") - output_selectors: list[str] = Field(default_factory=list, description="output selectors") diff --git a/api/dify_graph/graph/validation.py b/api/dify_graph/graph/validation.py deleted file mode 100644 index 50d1440b04..0000000000 --- a/api/dify_graph/graph/validation.py +++ /dev/null @@ -1,125 +0,0 @@ -from __future__ import annotations - -from collections.abc import Sequence -from dataclasses import dataclass -from typing import TYPE_CHECKING, Protocol - -from dify_graph.enums import BuiltinNodeTypes, NodeExecutionType, NodeType - -if TYPE_CHECKING: - from .graph import Graph - - -@dataclass(frozen=True, slots=True) -class GraphValidationIssue: - """Immutable value object 
describing a single validation issue.""" - - code: str - message: str - node_id: str | None = None - - -class GraphValidationError(ValueError): - """Raised when graph validation fails.""" - - def __init__(self, issues: Sequence[GraphValidationIssue]) -> None: - if not issues: - raise ValueError("GraphValidationError requires at least one issue.") - self.issues: tuple[GraphValidationIssue, ...] = tuple(issues) - message = "; ".join(f"[{issue.code}] {issue.message}" for issue in self.issues) - super().__init__(message) - - -class GraphValidationRule(Protocol): - """Protocol that individual validation rules must satisfy.""" - - def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]: - """Validate the provided graph and return any discovered issues.""" - ... - - -@dataclass(frozen=True, slots=True) -class _EdgeEndpointValidator: - """Ensures all edges reference existing nodes.""" - - missing_node_code: str = "MISSING_NODE" - - def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]: - issues: list[GraphValidationIssue] = [] - for edge in graph.edges.values(): - if edge.tail not in graph.nodes: - issues.append( - GraphValidationIssue( - code=self.missing_node_code, - message=f"Edge {edge.id} references unknown source node '{edge.tail}'.", - node_id=edge.tail, - ) - ) - if edge.head not in graph.nodes: - issues.append( - GraphValidationIssue( - code=self.missing_node_code, - message=f"Edge {edge.id} references unknown target node '{edge.head}'.", - node_id=edge.head, - ) - ) - return issues - - -@dataclass(frozen=True, slots=True) -class _RootNodeValidator: - """Validates root node invariants.""" - - invalid_root_code: str = "INVALID_ROOT" - container_entry_types: tuple[NodeType, ...] 
= (BuiltinNodeTypes.ITERATION_START, BuiltinNodeTypes.LOOP_START) - - def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]: - root_node = graph.root_node - issues: list[GraphValidationIssue] = [] - if root_node.id not in graph.nodes: - issues.append( - GraphValidationIssue( - code=self.invalid_root_code, - message=f"Root node '{root_node.id}' is missing from the node registry.", - node_id=root_node.id, - ) - ) - return issues - - node_type = root_node.node_type - if root_node.execution_type != NodeExecutionType.ROOT and node_type not in self.container_entry_types: - issues.append( - GraphValidationIssue( - code=self.invalid_root_code, - message=f"Root node '{root_node.id}' must declare execution type 'root'.", - node_id=root_node.id, - ) - ) - return issues - - -@dataclass(frozen=True, slots=True) -class GraphValidator: - """Coordinates execution of graph validation rules.""" - - rules: tuple[GraphValidationRule, ...] - - def validate(self, graph: Graph) -> None: - """Validate the graph against all configured rules.""" - issues: list[GraphValidationIssue] = [] - for rule in self.rules: - issues.extend(rule.validate(graph)) - - if issues: - raise GraphValidationError(issues) - - -_DEFAULT_RULES: tuple[GraphValidationRule, ...] 
= ( - _EdgeEndpointValidator(), - _RootNodeValidator(), -) - - -def get_graph_validator() -> GraphValidator: - """Construct the validator composed of default rules.""" - return GraphValidator(_DEFAULT_RULES) diff --git a/api/dify_graph/graph_engine/__init__.py b/api/dify_graph/graph_engine/__init__.py deleted file mode 100644 index 0e1c7dd60a..0000000000 --- a/api/dify_graph/graph_engine/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .config import GraphEngineConfig -from .graph_engine import GraphEngine - -__all__ = ["GraphEngine", "GraphEngineConfig"] diff --git a/api/dify_graph/graph_engine/_engine_utils.py b/api/dify_graph/graph_engine/_engine_utils.py deleted file mode 100644 index 28898268fe..0000000000 --- a/api/dify_graph/graph_engine/_engine_utils.py +++ /dev/null @@ -1,15 +0,0 @@ -import time - - -def get_timestamp() -> float: - """Retrieve a timestamp as a float point numer representing the number of seconds - since the Unix epoch. - - This function is primarily used to measure the execution time of the workflow engine. - Since workflow execution may be paused and resumed on a different machine, - `time.perf_counter` cannot be used as it is inconsistent across machines. - - To address this, the function uses the wall clock as the time source. - However, it assumes that the clocks of all servers are properly synchronized. - """ - return round(time.time()) diff --git a/api/dify_graph/graph_engine/command_channels/README.md b/api/dify_graph/graph_engine/command_channels/README.md deleted file mode 100644 index e35e12054a..0000000000 --- a/api/dify_graph/graph_engine/command_channels/README.md +++ /dev/null @@ -1,33 +0,0 @@ -# Command Channels - -Channel implementations for external workflow control. - -## Components - -### InMemoryChannel - -Thread-safe in-memory queue for single-process deployments. 
- -- `fetch_commands()` - Get pending commands -- `send_command()` - Add command to queue - -### RedisChannel - -Redis-based queue for distributed deployments. - -- `fetch_commands()` - Get commands with JSON deserialization -- `send_command()` - Store commands with TTL - -## Usage - -```python -# Local execution -channel = InMemoryChannel() -channel.send_command(AbortCommand(graph_id="workflow-123")) - -# Distributed execution -redis_channel = RedisChannel( - redis_client=redis_client, - channel_key="workflow:123:commands" -) -``` diff --git a/api/dify_graph/graph_engine/command_channels/__init__.py b/api/dify_graph/graph_engine/command_channels/__init__.py deleted file mode 100644 index 863e6032d6..0000000000 --- a/api/dify_graph/graph_engine/command_channels/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -"""Command channel implementations for GraphEngine.""" - -from .in_memory_channel import InMemoryChannel -from .redis_channel import RedisChannel - -__all__ = ["InMemoryChannel", "RedisChannel"] diff --git a/api/dify_graph/graph_engine/command_channels/in_memory_channel.py b/api/dify_graph/graph_engine/command_channels/in_memory_channel.py deleted file mode 100644 index bdaf236796..0000000000 --- a/api/dify_graph/graph_engine/command_channels/in_memory_channel.py +++ /dev/null @@ -1,53 +0,0 @@ -""" -In-memory implementation of CommandChannel for local/testing scenarios. - -This implementation uses a thread-safe queue for command communication -within a single process. Each instance handles commands for one workflow execution. -""" - -from queue import Queue -from typing import final - -from ..entities.commands import GraphEngineCommand - - -@final -class InMemoryChannel: - """ - In-memory command channel implementation using a thread-safe queue. - - Each instance is dedicated to a single GraphEngine/workflow execution. - Suitable for local development, testing, and single-instance deployments. 
- """ - - def __init__(self) -> None: - """Initialize the in-memory channel with a single queue.""" - self._queue: Queue[GraphEngineCommand] = Queue() - - def fetch_commands(self) -> list[GraphEngineCommand]: - """ - Fetch all pending commands from the queue. - - Returns: - List of pending commands (drains the queue) - """ - commands: list[GraphEngineCommand] = [] - - # Drain all available commands from the queue - while not self._queue.empty(): - try: - command = self._queue.get_nowait() - commands.append(command) - except Exception: - break - - return commands - - def send_command(self, command: GraphEngineCommand) -> None: - """ - Send a command to this channel's queue. - - Args: - command: The command to send - """ - self._queue.put(command) diff --git a/api/dify_graph/graph_engine/command_channels/redis_channel.py b/api/dify_graph/graph_engine/command_channels/redis_channel.py deleted file mode 100644 index 77cf884c67..0000000000 --- a/api/dify_graph/graph_engine/command_channels/redis_channel.py +++ /dev/null @@ -1,153 +0,0 @@ -""" -Redis-based implementation of CommandChannel for distributed scenarios. - -This implementation uses Redis lists for command queuing, supporting -multi-instance deployments and cross-server communication. -Each instance uses a unique key for its command queue. -""" - -import json -from contextlib import AbstractContextManager -from typing import Any, Protocol, final - -from ..entities.commands import AbortCommand, CommandType, GraphEngineCommand, PauseCommand, UpdateVariablesCommand - - -class RedisPipelineProtocol(Protocol): - """Minimal Redis pipeline contract used by the command channel.""" - - def lrange(self, name: str, start: int, end: int) -> Any: ... - def delete(self, *names: str) -> Any: ... - def execute(self) -> list[Any]: ... - def rpush(self, name: str, *values: str) -> Any: ... - def expire(self, name: str, time: int) -> Any: ... - def set(self, name: str, value: str, ex: int | None = None) -> Any: ... 
- def get(self, name: str) -> Any: ... - - -class RedisClientProtocol(Protocol): - """Redis client contract required by the command channel.""" - - def pipeline(self) -> AbstractContextManager[RedisPipelineProtocol]: ... - - -@final -class RedisChannel: - """ - Redis-based command channel implementation for distributed systems. - - Each instance uses a unique Redis key for its command queue. - Commands are JSON-serialized for transport. - """ - - def __init__( - self, - redis_client: RedisClientProtocol, - channel_key: str, - command_ttl: int = 3600, - ) -> None: - """ - Initialize the Redis channel. - - Args: - redis_client: Redis client instance - channel_key: Unique key for this channel's command queue - command_ttl: TTL for command keys in seconds (default: 3600) - """ - self._redis = redis_client - self._key = channel_key - self._command_ttl = command_ttl - self._pending_key = f"{channel_key}:pending" - - def fetch_commands(self) -> list[GraphEngineCommand]: - """ - Fetch all pending commands from Redis. - - Returns: - List of pending commands (drains the Redis list) - """ - if not self._has_pending_commands(): - return [] - - commands: list[GraphEngineCommand] = [] - - # Use pipeline for atomic operations - with self._redis.pipeline() as pipe: - # Get all commands and clear the list atomically - pipe.lrange(self._key, 0, -1) - pipe.delete(self._key) - results = pipe.execute() - - # Parse commands from JSON - if results[0]: - for command_json in results[0]: - try: - command_data = json.loads(command_json) - command = self._deserialize_command(command_data) - if command: - commands.append(command) - except (json.JSONDecodeError, ValueError): - # Skip invalid commands - continue - - return commands - - def send_command(self, command: GraphEngineCommand) -> None: - """ - Send a command to Redis. 
- - Args: - command: The command to send - """ - command_json = json.dumps(command.model_dump()) - - # Push to list and set expiry - with self._redis.pipeline() as pipe: - pipe.rpush(self._key, command_json) - pipe.expire(self._key, self._command_ttl) - pipe.set(self._pending_key, "1", ex=self._command_ttl) - pipe.execute() - - def _deserialize_command(self, data: dict[str, Any]) -> GraphEngineCommand | None: - """ - Deserialize a command from dictionary data. - - Args: - data: Command data dictionary - - Returns: - Deserialized command or None if invalid - """ - command_type_value = data.get("command_type") - if not isinstance(command_type_value, str): - return None - - try: - command_type = CommandType(command_type_value) - - if command_type == CommandType.ABORT: - return AbortCommand.model_validate(data) - if command_type == CommandType.PAUSE: - return PauseCommand.model_validate(data) - if command_type == CommandType.UPDATE_VARIABLES: - return UpdateVariablesCommand.model_validate(data) - - # For other command types, use base class - return GraphEngineCommand.model_validate(data) - - except (ValueError, TypeError): - return None - - def _has_pending_commands(self) -> bool: - """ - Check and consume the pending marker to avoid unnecessary list reads. - - Returns: - True if commands should be fetched from Redis. - """ - with self._redis.pipeline() as pipe: - pipe.get(self._pending_key) - pipe.delete(self._pending_key) - pending_value, _ = pipe.execute() - - return pending_value is not None diff --git a/api/dify_graph/graph_engine/command_processing/__init__.py b/api/dify_graph/graph_engine/command_processing/__init__.py deleted file mode 100644 index 7b4f0dfff7..0000000000 --- a/api/dify_graph/graph_engine/command_processing/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -""" -Command processing subsystem for graph engine. - -This package handles external commands sent to the engine -during execution. 
-""" - -from .command_handlers import AbortCommandHandler, PauseCommandHandler, UpdateVariablesCommandHandler -from .command_processor import CommandProcessor - -__all__ = [ - "AbortCommandHandler", - "CommandProcessor", - "PauseCommandHandler", - "UpdateVariablesCommandHandler", -] diff --git a/api/dify_graph/graph_engine/command_processing/command_handlers.py b/api/dify_graph/graph_engine/command_processing/command_handlers.py deleted file mode 100644 index eefd0c366b..0000000000 --- a/api/dify_graph/graph_engine/command_processing/command_handlers.py +++ /dev/null @@ -1,56 +0,0 @@ -import logging -from typing import final - -from typing_extensions import override - -from dify_graph.entities.pause_reason import SchedulingPause -from dify_graph.runtime import VariablePool - -from ..domain.graph_execution import GraphExecution -from ..entities.commands import AbortCommand, GraphEngineCommand, PauseCommand, UpdateVariablesCommand -from .command_processor import CommandHandler - -logger = logging.getLogger(__name__) - - -@final -class AbortCommandHandler(CommandHandler): - @override - def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None: - assert isinstance(command, AbortCommand) - logger.debug("Aborting workflow %s: %s", execution.workflow_id, command.reason) - execution.abort(command.reason or "User requested abort") - - -@final -class PauseCommandHandler(CommandHandler): - @override - def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None: - assert isinstance(command, PauseCommand) - logger.debug("Pausing workflow %s: %s", execution.workflow_id, command.reason) - # Convert string reason to PauseReason if needed - reason = command.reason - pause_reason = SchedulingPause(message=reason) - execution.pause(pause_reason) - - -@final -class UpdateVariablesCommandHandler(CommandHandler): - def __init__(self, variable_pool: VariablePool) -> None: - self._variable_pool = variable_pool - - @override - def handle(self, 
command: GraphEngineCommand, execution: GraphExecution) -> None: - assert isinstance(command, UpdateVariablesCommand) - for update in command.updates: - try: - variable = update.value - self._variable_pool.add(variable.selector, variable) - logger.debug("Updated variable %s for workflow %s", variable.selector, execution.workflow_id) - except ValueError as exc: - logger.warning( - "Skipping invalid variable selector %s for workflow %s: %s", - getattr(update.value, "selector", None), - execution.workflow_id, - exc, - ) diff --git a/api/dify_graph/graph_engine/command_processing/command_processor.py b/api/dify_graph/graph_engine/command_processing/command_processor.py deleted file mode 100644 index 942c2d77a5..0000000000 --- a/api/dify_graph/graph_engine/command_processing/command_processor.py +++ /dev/null @@ -1,79 +0,0 @@ -""" -Main command processor for handling external commands. -""" - -import logging -from typing import Protocol, final - -from ..domain.graph_execution import GraphExecution -from ..entities.commands import GraphEngineCommand -from ..protocols.command_channel import CommandChannel - -logger = logging.getLogger(__name__) - - -class CommandHandler(Protocol): - """Protocol for command handlers.""" - - def handle(self, command: GraphEngineCommand, execution: GraphExecution) -> None: ... - - -@final -class CommandProcessor: - """ - Processes external commands sent to the engine. - - This polls the command channel and dispatches commands to - appropriate handlers. - """ - - def __init__( - self, - command_channel: CommandChannel, - graph_execution: GraphExecution, - ) -> None: - """ - Initialize the command processor. 
- - Args: - command_channel: Channel for receiving commands - graph_execution: Graph execution aggregate - """ - self._command_channel = command_channel - self._graph_execution = graph_execution - self._handlers: dict[type[GraphEngineCommand], CommandHandler] = {} - - def register_handler(self, command_type: type[GraphEngineCommand], handler: CommandHandler) -> None: - """ - Register a handler for a command type. - - Args: - command_type: Type of command to handle - handler: Handler for the command - """ - self._handlers[command_type] = handler - - def process_commands(self) -> None: - """Check for and process any pending commands.""" - try: - commands = self._command_channel.fetch_commands() - for command in commands: - self._handle_command(command) - except Exception as e: - logger.warning("Error processing commands: %s", e) - - def _handle_command(self, command: GraphEngineCommand) -> None: - """ - Handle a single command. - - Args: - command: The command to handle - """ - handler = self._handlers.get(type(command)) - if handler: - try: - handler.handle(command, self._graph_execution) - except Exception: - logger.exception("Error handling command %s", command.__class__.__name__) - else: - logger.warning("No handler registered for command: %s", command.__class__.__name__) diff --git a/api/dify_graph/graph_engine/config.py b/api/dify_graph/graph_engine/config.py deleted file mode 100644 index d56a69cee0..0000000000 --- a/api/dify_graph/graph_engine/config.py +++ /dev/null @@ -1,16 +0,0 @@ -""" -GraphEngine configuration models. 
-""" - -from pydantic import BaseModel, ConfigDict - - -class GraphEngineConfig(BaseModel): - """Configuration for GraphEngine worker pool scaling.""" - - model_config = ConfigDict(frozen=True) - - min_workers: int = 1 - max_workers: int = 5 - scale_up_threshold: int = 3 - scale_down_idle_time: float = 5.0 diff --git a/api/dify_graph/graph_engine/domain/__init__.py b/api/dify_graph/graph_engine/domain/__init__.py deleted file mode 100644 index 9e9afe4c21..0000000000 --- a/api/dify_graph/graph_engine/domain/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -""" -Domain models for graph engine. - -This package contains the core domain entities, value objects, and aggregates -that represent the business concepts of workflow graph execution. -""" - -from .graph_execution import GraphExecution -from .node_execution import NodeExecution - -__all__ = [ - "GraphExecution", - "NodeExecution", -] diff --git a/api/dify_graph/graph_engine/domain/graph_execution.py b/api/dify_graph/graph_engine/domain/graph_execution.py deleted file mode 100644 index 0ee4a9f9a7..0000000000 --- a/api/dify_graph/graph_engine/domain/graph_execution.py +++ /dev/null @@ -1,242 +0,0 @@ -"""GraphExecution aggregate root managing the overall graph execution state.""" - -from __future__ import annotations - -from dataclasses import dataclass, field -from importlib import import_module -from typing import Literal - -from pydantic import BaseModel, Field - -from dify_graph.entities.pause_reason import PauseReason -from dify_graph.enums import NodeState -from dify_graph.runtime.graph_runtime_state import GraphExecutionProtocol - -from .node_execution import NodeExecution - - -class GraphExecutionErrorState(BaseModel): - """Serializable representation of an execution error.""" - - module: str = Field(description="Module containing the exception class") - qualname: str = Field(description="Qualified name of the exception class") - message: str | None = Field(default=None, description="Exception message string") - - 
-class NodeExecutionState(BaseModel): - """Serializable representation of a node execution entity.""" - - node_id: str - state: NodeState = Field(default=NodeState.UNKNOWN) - retry_count: int = Field(default=0) - execution_id: str | None = Field(default=None) - error: str | None = Field(default=None) - - -class GraphExecutionState(BaseModel): - """Pydantic model describing serialized GraphExecution state.""" - - type: Literal["GraphExecution"] = Field(default="GraphExecution") - version: str = Field(default="1.0") - workflow_id: str - started: bool = Field(default=False) - completed: bool = Field(default=False) - aborted: bool = Field(default=False) - paused: bool = Field(default=False) - pause_reasons: list[PauseReason] = Field(default_factory=list) - error: GraphExecutionErrorState | None = Field(default=None) - exceptions_count: int = Field(default=0) - node_executions: list[NodeExecutionState] = Field(default_factory=list[NodeExecutionState]) - - -def _serialize_error(error: Exception | None) -> GraphExecutionErrorState | None: - """Convert an exception into its serializable representation.""" - - if error is None: - return None - - return GraphExecutionErrorState( - module=error.__class__.__module__, - qualname=error.__class__.__qualname__, - message=str(error), - ) - - -def _resolve_exception_class(module_name: str, qualname: str) -> type[Exception]: - """Locate an exception class from its module and qualified name.""" - - module = import_module(module_name) - attr: object = module - for part in qualname.split("."): - attr = getattr(attr, part) - - if isinstance(attr, type) and issubclass(attr, Exception): - return attr - - raise TypeError(f"{qualname} in {module_name} is not an Exception subclass") - - -def _deserialize_error(state: GraphExecutionErrorState | None) -> Exception | None: - """Reconstruct an exception instance from serialized data.""" - - if state is None: - return None - - try: - exception_class = _resolve_exception_class(state.module, 
state.qualname) - if state.message is None: - return exception_class() - return exception_class(state.message) - except Exception: - # Fallback to RuntimeError when reconstruction fails - if state.message is None: - return RuntimeError(state.qualname) - return RuntimeError(state.message) - - -@dataclass -class GraphExecution: - """ - Aggregate root for graph execution. - - This manages the overall execution state of a workflow graph, - coordinating between multiple node executions. - """ - - workflow_id: str - started: bool = False - completed: bool = False - aborted: bool = False - paused: bool = False - pause_reasons: list[PauseReason] = field(default_factory=list) - error: Exception | None = None - node_executions: dict[str, NodeExecution] = field(default_factory=dict[str, NodeExecution]) - exceptions_count: int = 0 - - def start(self) -> None: - """Mark the graph execution as started.""" - if self.started: - raise RuntimeError("Graph execution already started") - self.started = True - - def complete(self) -> None: - """Mark the graph execution as completed.""" - if not self.started: - raise RuntimeError("Cannot complete execution that hasn't started") - if self.completed: - raise RuntimeError("Graph execution already completed") - self.completed = True - - def abort(self, reason: str) -> None: - """Abort the graph execution.""" - self.aborted = True - self.error = RuntimeError(f"Aborted: {reason}") - - def pause(self, reason: PauseReason) -> None: - """Pause the graph execution without marking it complete.""" - if self.completed: - raise RuntimeError("Cannot pause execution that has completed") - if self.aborted: - raise RuntimeError("Cannot pause execution that has been aborted") - self.paused = True - self.pause_reasons.append(reason) - - def fail(self, error: Exception) -> None: - """Mark the graph execution as failed.""" - self.error = error - self.completed = True - - def get_or_create_node_execution(self, node_id: str) -> NodeExecution: - """Get or create 
a node execution entity.""" - if node_id not in self.node_executions: - self.node_executions[node_id] = NodeExecution(node_id=node_id) - return self.node_executions[node_id] - - @property - def is_running(self) -> bool: - """Check if the execution is currently running.""" - return self.started and not self.completed and not self.aborted and not self.paused - - @property - def is_paused(self) -> bool: - """Check if the execution is currently paused.""" - return self.paused - - @property - def has_error(self) -> bool: - """Check if the execution has encountered an error.""" - return self.error is not None - - @property - def error_message(self) -> str | None: - """Get the error message if an error exists.""" - if not self.error: - return None - return str(self.error) - - def dumps(self) -> str: - """Serialize the aggregate state into a JSON string.""" - - node_states = [ - NodeExecutionState( - node_id=node_id, - state=node_execution.state, - retry_count=node_execution.retry_count, - execution_id=node_execution.execution_id, - error=node_execution.error, - ) - for node_id, node_execution in sorted(self.node_executions.items()) - ] - - state = GraphExecutionState( - workflow_id=self.workflow_id, - started=self.started, - completed=self.completed, - aborted=self.aborted, - paused=self.paused, - pause_reasons=self.pause_reasons, - error=_serialize_error(self.error), - exceptions_count=self.exceptions_count, - node_executions=node_states, - ) - - return state.model_dump_json() - - def loads(self, data: str) -> None: - """Restore aggregate state from a serialized JSON string.""" - - state = GraphExecutionState.model_validate_json(data) - - if state.type != "GraphExecution": - raise ValueError(f"Invalid serialized data type: {state.type}") - - if state.version != "1.0": - raise ValueError(f"Unsupported serialized version: {state.version}") - - if self.workflow_id != state.workflow_id: - raise ValueError("Serialized workflow_id does not match aggregate identity") - - 
self.started = state.started - self.completed = state.completed - self.aborted = state.aborted - self.paused = state.paused - self.pause_reasons = state.pause_reasons - self.error = _deserialize_error(state.error) - self.exceptions_count = state.exceptions_count - self.node_executions = { - item.node_id: NodeExecution( - node_id=item.node_id, - state=item.state, - retry_count=item.retry_count, - execution_id=item.execution_id, - error=item.error, - ) - for item in state.node_executions - } - - def record_node_failure(self) -> None: - """Increment the count of node failures encountered during execution.""" - self.exceptions_count += 1 - - -_: GraphExecutionProtocol = GraphExecution(workflow_id="") diff --git a/api/dify_graph/graph_engine/domain/node_execution.py b/api/dify_graph/graph_engine/domain/node_execution.py deleted file mode 100644 index ae8f9a5e50..0000000000 --- a/api/dify_graph/graph_engine/domain/node_execution.py +++ /dev/null @@ -1,45 +0,0 @@ -""" -NodeExecution entity representing a node's execution state. -""" - -from dataclasses import dataclass - -from dify_graph.enums import NodeState - - -@dataclass -class NodeExecution: - """ - Entity representing the execution state of a single node. - - This is a mutable entity that tracks the runtime state of a node - during graph execution. 
- """ - - node_id: str - state: NodeState = NodeState.UNKNOWN - retry_count: int = 0 - execution_id: str | None = None - error: str | None = None - - def mark_started(self, execution_id: str) -> None: - """Mark the node as started with an execution ID.""" - self.state = NodeState.TAKEN - self.execution_id = execution_id - - def mark_taken(self) -> None: - """Mark the node as successfully completed.""" - self.state = NodeState.TAKEN - self.error = None - - def mark_failed(self, error: str) -> None: - """Mark the node as failed with an error.""" - self.error = error - - def mark_skipped(self) -> None: - """Mark the node as skipped.""" - self.state = NodeState.SKIPPED - - def increment_retry(self) -> None: - """Increment the retry count for this node.""" - self.retry_count += 1 diff --git a/api/dify_graph/graph_engine/entities/commands.py b/api/dify_graph/graph_engine/entities/commands.py deleted file mode 100644 index c56845cfc4..0000000000 --- a/api/dify_graph/graph_engine/entities/commands.py +++ /dev/null @@ -1,56 +0,0 @@ -""" -GraphEngine command entities for external control. - -This module defines command types that can be sent to a running GraphEngine -instance to control its execution flow. 
-""" - -from collections.abc import Sequence -from enum import StrEnum, auto -from typing import Any - -from pydantic import BaseModel, Field - -from dify_graph.variables.variables import Variable - - -class CommandType(StrEnum): - """Types of commands that can be sent to GraphEngine.""" - - ABORT = auto() - PAUSE = auto() - UPDATE_VARIABLES = auto() - - -class GraphEngineCommand(BaseModel): - """Base class for all GraphEngine commands.""" - - command_type: CommandType = Field(..., description="Type of command") - payload: dict[str, Any] | None = Field(default=None, description="Optional command payload") - - -class AbortCommand(GraphEngineCommand): - """Command to abort a running workflow execution.""" - - command_type: CommandType = Field(default=CommandType.ABORT, description="Type of command") - reason: str | None = Field(default=None, description="Optional reason for abort") - - -class PauseCommand(GraphEngineCommand): - """Command to pause a running workflow execution.""" - - command_type: CommandType = Field(default=CommandType.PAUSE, description="Type of command") - reason: str = Field(default="unknown reason", description="reason for pause") - - -class VariableUpdate(BaseModel): - """Represents a single variable update instruction.""" - - value: Variable = Field(description="New variable value") - - -class UpdateVariablesCommand(GraphEngineCommand): - """Command to update a group of variables in the variable pool.""" - - command_type: CommandType = Field(default=CommandType.UPDATE_VARIABLES, description="Type of command") - updates: Sequence[VariableUpdate] = Field(default_factory=list, description="Variable updates") diff --git a/api/dify_graph/graph_engine/error_handler.py b/api/dify_graph/graph_engine/error_handler.py deleted file mode 100644 index e206f21592..0000000000 --- a/api/dify_graph/graph_engine/error_handler.py +++ /dev/null @@ -1,213 +0,0 @@ -""" -Main error handler that coordinates error strategies. 
-""" - -import logging -import time -from typing import TYPE_CHECKING, final - -from dify_graph.enums import ( - ErrorStrategy as ErrorStrategyEnum, -) -from dify_graph.enums import ( - WorkflowNodeExecutionMetadataKey, - WorkflowNodeExecutionStatus, -) -from dify_graph.graph import Graph -from dify_graph.graph_events import ( - GraphNodeEventBase, - NodeRunExceptionEvent, - NodeRunFailedEvent, - NodeRunRetryEvent, -) -from dify_graph.node_events import NodeRunResult - -if TYPE_CHECKING: - from .domain import GraphExecution - -logger = logging.getLogger(__name__) - - -@final -class ErrorHandler: - """ - Coordinates error handling strategies for node failures. - - This acts as a facade for the various error strategies, - selecting and applying the appropriate strategy based on - node configuration. - """ - - def __init__(self, graph: Graph, graph_execution: "GraphExecution") -> None: - """ - Initialize the error handler. - - Args: - graph: The workflow graph - graph_execution: The graph execution state - """ - self._graph = graph - self._graph_execution = graph_execution - - def handle_node_failure(self, event: NodeRunFailedEvent) -> GraphNodeEventBase | None: - """ - Handle a node failure event. - - Selects and applies the appropriate error strategy based on - the node's configuration. 
- - Args: - event: The node failure event - - Returns: - Optional new event to process, or None to abort - """ - node = self._graph.nodes[event.node_id] - # Get retry count from NodeExecution - node_execution = self._graph_execution.get_or_create_node_execution(event.node_id) - retry_count = node_execution.retry_count - - # First check if retry is configured and not exhausted - if node.retry and retry_count < node.retry_config.max_retries: - result = self._handle_retry(event, retry_count) - if result: - # Retry count will be incremented when NodeRunRetryEvent is handled - return result - - # Apply configured error strategy - strategy = node.error_strategy - - match strategy: - case None: - return self._handle_abort(event) - case ErrorStrategyEnum.FAIL_BRANCH: - return self._handle_fail_branch(event) - case ErrorStrategyEnum.DEFAULT_VALUE: - return self._handle_default_value(event) - - def _handle_abort(self, event: NodeRunFailedEvent): - """ - Handle error by aborting execution. - - This is the default strategy when no other strategy is specified. - It stops the entire graph execution when a node fails. - - Args: - event: The failure event - - Returns: - None - signals abortion - """ - logger.error("Node %s failed with ABORT strategy: %s", event.node_id, event.error) - # Return None to signal that execution should stop - - def _handle_retry(self, event: NodeRunFailedEvent, retry_count: int): - """ - Handle error by retrying the node. - - This strategy re-attempts node execution up to a configured - maximum number of retries with configurable intervals. 
- - Args: - event: The failure event - retry_count: Current retry attempt count - - Returns: - NodeRunRetryEvent if retry should occur, None otherwise - """ - node = self._graph.nodes[event.node_id] - - # Check if we've exceeded max retries - if not node.retry or retry_count >= node.retry_config.max_retries: - return None - - # Wait for retry interval - time.sleep(node.retry_config.retry_interval_seconds) - - # Create retry event - return NodeRunRetryEvent( - id=event.id, - node_title=node.title, - node_id=event.node_id, - node_type=event.node_type, - node_run_result=event.node_run_result, - start_at=event.start_at, - error=event.error, - retry_index=retry_count + 1, - ) - - def _handle_fail_branch(self, event: NodeRunFailedEvent): - """ - Handle error by taking the fail branch. - - This strategy converts failures to exceptions and routes execution - through a designated fail-branch edge. - - Args: - event: The failure event - - Returns: - NodeRunExceptionEvent to continue via fail branch - """ - outputs = { - "error_message": event.node_run_result.error, - "error_type": event.node_run_result.error_type, - } - - return NodeRunExceptionEvent( - id=event.id, - node_id=event.node_id, - node_type=event.node_type, - start_at=event.start_at, - finished_at=event.finished_at, - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.EXCEPTION, - inputs=event.node_run_result.inputs, - process_data=event.node_run_result.process_data, - outputs=outputs, - edge_source_handle="fail-branch", - metadata={ - WorkflowNodeExecutionMetadataKey.ERROR_STRATEGY: ErrorStrategyEnum.FAIL_BRANCH, - }, - ), - error=event.error, - ) - - def _handle_default_value(self, event: NodeRunFailedEvent): - """ - Handle error by using default values. - - This strategy allows nodes to fail gracefully by providing - predefined default output values. 
- - Args: - event: The failure event - - Returns: - NodeRunExceptionEvent with default values - """ - node = self._graph.nodes[event.node_id] - - outputs = { - **node.default_value_dict, - "error_message": event.node_run_result.error, - "error_type": event.node_run_result.error_type, - } - - return NodeRunExceptionEvent( - id=event.id, - node_id=event.node_id, - node_type=event.node_type, - start_at=event.start_at, - finished_at=event.finished_at, - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.EXCEPTION, - inputs=event.node_run_result.inputs, - process_data=event.node_run_result.process_data, - outputs=outputs, - metadata={ - WorkflowNodeExecutionMetadataKey.ERROR_STRATEGY: ErrorStrategyEnum.DEFAULT_VALUE, - }, - ), - error=event.error, - ) diff --git a/api/dify_graph/graph_engine/event_management/__init__.py b/api/dify_graph/graph_engine/event_management/__init__.py deleted file mode 100644 index f6c3c0f753..0000000000 --- a/api/dify_graph/graph_engine/event_management/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -""" -Event management subsystem for graph engine. - -This package handles event routing, collection, and emission for -workflow graph execution events. -""" - -from .event_handlers import EventHandler -from .event_manager import EventManager - -__all__ = [ - "EventHandler", - "EventManager", -] diff --git a/api/dify_graph/graph_engine/event_management/event_handlers.py b/api/dify_graph/graph_engine/event_management/event_handlers.py deleted file mode 100644 index 7f5ad40e0e..0000000000 --- a/api/dify_graph/graph_engine/event_management/event_handlers.py +++ /dev/null @@ -1,351 +0,0 @@ -""" -Event handler implementations for different event types. 
-""" - -import logging -from collections.abc import Mapping -from functools import singledispatchmethod -from typing import TYPE_CHECKING, final - -from dify_graph.enums import ErrorStrategy, NodeExecutionType, NodeState -from dify_graph.graph import Graph -from dify_graph.graph_events import ( - GraphNodeEventBase, - NodeRunAgentLogEvent, - NodeRunExceptionEvent, - NodeRunFailedEvent, - NodeRunIterationFailedEvent, - NodeRunIterationNextEvent, - NodeRunIterationStartedEvent, - NodeRunIterationSucceededEvent, - NodeRunLoopFailedEvent, - NodeRunLoopNextEvent, - NodeRunLoopStartedEvent, - NodeRunLoopSucceededEvent, - NodeRunPauseRequestedEvent, - NodeRunRetrieverResourceEvent, - NodeRunRetryEvent, - NodeRunStartedEvent, - NodeRunStreamChunkEvent, - NodeRunSucceededEvent, -) -from dify_graph.model_runtime.entities.llm_entities import LLMUsage -from dify_graph.runtime import GraphRuntimeState - -from ..domain.graph_execution import GraphExecution -from ..response_coordinator import ResponseStreamCoordinator - -if TYPE_CHECKING: - from ..error_handler import ErrorHandler - from ..graph_state_manager import GraphStateManager - from ..graph_traversal import EdgeProcessor - from .event_manager import EventManager - -logger = logging.getLogger(__name__) - - -@final -class EventHandler: - """ - Registry of event handlers for different event types. - - This centralizes the business logic for handling specific events, - keeping it separate from the routing and collection infrastructure. - """ - - def __init__( - self, - graph: Graph, - graph_runtime_state: GraphRuntimeState, - graph_execution: GraphExecution, - response_coordinator: ResponseStreamCoordinator, - event_collector: "EventManager", - edge_processor: "EdgeProcessor", - state_manager: "GraphStateManager", - error_handler: "ErrorHandler", - ) -> None: - """ - Initialize the event handler registry. 
- - Args: - graph: The workflow graph - graph_runtime_state: Runtime state with variable pool - graph_execution: Graph execution aggregate - response_coordinator: Response stream coordinator - event_collector: Event manager for collecting events - edge_processor: Edge processor for edge traversal - state_manager: Unified state manager - error_handler: Error handler - """ - self._graph = graph - self._graph_runtime_state = graph_runtime_state - self._graph_execution = graph_execution - self._response_coordinator = response_coordinator - self._event_collector = event_collector - self._edge_processor = edge_processor - self._state_manager = state_manager - self._error_handler = error_handler - - def dispatch(self, event: GraphNodeEventBase) -> None: - """ - Handle any node event by dispatching to the appropriate handler. - - Args: - event: The event to handle - """ - # Events in loops or iterations are always collected - if event.in_loop_id or event.in_iteration_id: - self._event_collector.collect(event) - return - return self._dispatch(event) - - @singledispatchmethod - def _dispatch(self, event: GraphNodeEventBase) -> None: - self._event_collector.collect(event) - logger.warning("Unhandled event type: %s", type(event).__name__) - - @_dispatch.register(NodeRunIterationStartedEvent) - @_dispatch.register(NodeRunIterationNextEvent) - @_dispatch.register(NodeRunIterationSucceededEvent) - @_dispatch.register(NodeRunIterationFailedEvent) - @_dispatch.register(NodeRunLoopStartedEvent) - @_dispatch.register(NodeRunLoopNextEvent) - @_dispatch.register(NodeRunLoopSucceededEvent) - @_dispatch.register(NodeRunLoopFailedEvent) - @_dispatch.register(NodeRunAgentLogEvent) - @_dispatch.register(NodeRunRetrieverResourceEvent) - def _(self, event: GraphNodeEventBase) -> None: - self._event_collector.collect(event) - - @_dispatch.register - def _(self, event: NodeRunStartedEvent) -> None: - """ - Handle node started event. 
- - Args: - event: The node started event - """ - # Track execution in domain model - node_execution = self._graph_execution.get_or_create_node_execution(event.node_id) - is_initial_attempt = node_execution.retry_count == 0 - node_execution.mark_started(event.id) - self._graph_runtime_state.increment_node_run_steps() - - # Track in response coordinator for stream ordering - self._response_coordinator.track_node_execution(event.node_id, event.id) - - # Collect the event only for the first attempt; retries remain silent - if is_initial_attempt: - self._event_collector.collect(event) - - @_dispatch.register - def _(self, event: NodeRunStreamChunkEvent) -> None: - """ - Handle stream chunk event with full processing. - - Args: - event: The stream chunk event - """ - # Process with response coordinator - streaming_events = list(self._response_coordinator.intercept_event(event)) - - # Collect all events - for stream_event in streaming_events: - self._event_collector.collect(stream_event) - - @_dispatch.register - def _(self, event: NodeRunSucceededEvent) -> None: - """ - Handle node success by coordinating subsystems. - - This method coordinates between different subsystems to process - node completion, handle edges, and trigger downstream execution. 
- - Args: - event: The node succeeded event - """ - # Update domain model - node_execution = self._graph_execution.get_or_create_node_execution(event.node_id) - node_execution.mark_taken() - - self._accumulate_node_usage(event.node_run_result.llm_usage) - - # Store outputs in variable pool - self._store_node_outputs(event.node_id, event.node_run_result.outputs) - - # Forward to response coordinator and emit streaming events - streaming_events = self._response_coordinator.intercept_event(event) - for stream_event in streaming_events: - self._event_collector.collect(stream_event) - - # Process edges and get ready nodes - node = self._graph.nodes[event.node_id] - if node.execution_type == NodeExecutionType.BRANCH: - ready_nodes, edge_streaming_events = self._edge_processor.handle_branch_completion( - event.node_id, event.node_run_result.edge_source_handle - ) - else: - ready_nodes, edge_streaming_events = self._edge_processor.process_node_success(event.node_id) - - # Collect streaming events from edge processing - for edge_event in edge_streaming_events: - self._event_collector.collect(edge_event) - - # Enqueue ready nodes - if self._graph_execution.is_paused: - for node_id in ready_nodes: - self._graph_runtime_state.register_deferred_node(node_id) - else: - for node_id in ready_nodes: - self._state_manager.enqueue_node(node_id) - self._state_manager.start_execution(node_id) - - # Update execution tracking - self._state_manager.finish_execution(event.node_id) - - # Handle response node outputs - if node.execution_type == NodeExecutionType.RESPONSE: - self._update_response_outputs(event.node_run_result.outputs) - - # Collect the event - self._event_collector.collect(event) - - @_dispatch.register - def _(self, event: NodeRunPauseRequestedEvent) -> None: - """Handle pause requests emitted by nodes.""" - - pause_reason = event.reason - self._graph_execution.pause(pause_reason) - self._state_manager.finish_execution(event.node_id) - if event.node_id in self._graph.nodes: 
- self._graph.nodes[event.node_id].state = NodeState.UNKNOWN - self._graph_runtime_state.register_paused_node(event.node_id) - self._event_collector.collect(event) - - @_dispatch.register - def _(self, event: NodeRunFailedEvent) -> None: - """ - Handle node failure using error handler. - - Args: - event: The node failed event - """ - # Update domain model - node_execution = self._graph_execution.get_or_create_node_execution(event.node_id) - node_execution.mark_failed(event.error) - self._graph_execution.record_node_failure() - - self._accumulate_node_usage(event.node_run_result.llm_usage) - - result = self._error_handler.handle_node_failure(event) - - if result: - # Process the resulting event (retry, exception, etc.) - self.dispatch(result) - else: - # Abort execution - self._graph_execution.fail(RuntimeError(event.error)) - self._event_collector.collect(event) - self._state_manager.finish_execution(event.node_id) - - @_dispatch.register - def _(self, event: NodeRunExceptionEvent) -> None: - """ - Handle node exception event (fail-branch strategy). - - Args: - event: The node exception event - """ - # Node continues via fail-branch/default-value, treat as completion - node_execution = self._graph_execution.get_or_create_node_execution(event.node_id) - node_execution.mark_taken() - - self._accumulate_node_usage(event.node_run_result.llm_usage) - - # Persist outputs produced by the exception strategy (e.g. 
default values) - self._store_node_outputs(event.node_id, event.node_run_result.outputs) - - node = self._graph.nodes[event.node_id] - - if node.error_strategy == ErrorStrategy.DEFAULT_VALUE: - ready_nodes, edge_streaming_events = self._edge_processor.process_node_success(event.node_id) - elif node.error_strategy == ErrorStrategy.FAIL_BRANCH: - ready_nodes, edge_streaming_events = self._edge_processor.handle_branch_completion( - event.node_id, event.node_run_result.edge_source_handle - ) - else: - raise NotImplementedError(f"Unsupported error strategy: {node.error_strategy}") - - for edge_event in edge_streaming_events: - self._event_collector.collect(edge_event) - - for node_id in ready_nodes: - self._state_manager.enqueue_node(node_id) - self._state_manager.start_execution(node_id) - - # Update response outputs if applicable - if node.execution_type == NodeExecutionType.RESPONSE: - self._update_response_outputs(event.node_run_result.outputs) - - self._state_manager.finish_execution(event.node_id) - - # Collect the exception event for observers - self._event_collector.collect(event) - - @_dispatch.register - def _(self, event: NodeRunRetryEvent) -> None: - """ - Handle node retry event. 
- - Args: - event: The node retry event - """ - node_execution = self._graph_execution.get_or_create_node_execution(event.node_id) - node_execution.increment_retry() - - # Finish the previous attempt before re-queuing the node - self._state_manager.finish_execution(event.node_id) - - # Emit retry event for observers - self._event_collector.collect(event) - - # Re-queue node for execution - self._state_manager.enqueue_node(event.node_id) - self._state_manager.start_execution(event.node_id) - - def _accumulate_node_usage(self, usage: LLMUsage) -> None: - """Accumulate token usage into the shared runtime state.""" - if usage.total_tokens <= 0: - return - - self._graph_runtime_state.add_tokens(usage.total_tokens) - - current_usage = self._graph_runtime_state.llm_usage - if current_usage.total_tokens == 0: - self._graph_runtime_state.llm_usage = usage - else: - self._graph_runtime_state.llm_usage = current_usage.plus(usage) - - def _store_node_outputs(self, node_id: str, outputs: Mapping[str, object]) -> None: - """ - Store node outputs in the variable pool. - - Args: - event: The node succeeded event containing outputs - """ - for variable_name, variable_value in outputs.items(): - self._graph_runtime_state.variable_pool.add((node_id, variable_name), variable_value) - - def _update_response_outputs(self, outputs: Mapping[str, object]) -> None: - """Update response outputs for response nodes.""" - # TODO: Design a mechanism for nodes to notify the engine about how to update outputs - # in runtime state, rather than allowing nodes to directly access runtime state. 
- for key, value in outputs.items(): - if key == "answer": - existing = self._graph_runtime_state.get_output("answer", "") - if existing: - self._graph_runtime_state.set_output("answer", f"{existing}{value}") - else: - self._graph_runtime_state.set_output("answer", value) - else: - self._graph_runtime_state.set_output(key, value) diff --git a/api/dify_graph/graph_engine/event_management/event_manager.py b/api/dify_graph/graph_engine/event_management/event_manager.py deleted file mode 100644 index 616f621c3e..0000000000 --- a/api/dify_graph/graph_engine/event_management/event_manager.py +++ /dev/null @@ -1,186 +0,0 @@ -""" -Unified event manager for collecting and emitting events. -""" - -import logging -import threading -import time -from collections.abc import Generator -from contextlib import contextmanager -from typing import final - -from dify_graph.graph_events import GraphEngineEvent - -from ..layers.base import GraphEngineLayer - -_logger = logging.getLogger(__name__) - - -@final -class ReadWriteLock: - """ - A read-write lock implementation that allows multiple concurrent readers - but only one writer at a time. 
- """ - - def __init__(self) -> None: - self._read_ready = threading.Condition(threading.RLock()) - self._readers = 0 - - def acquire_read(self) -> None: - """Acquire a read lock.""" - _ = self._read_ready.acquire() - try: - self._readers += 1 - finally: - self._read_ready.release() - - def release_read(self) -> None: - """Release a read lock.""" - _ = self._read_ready.acquire() - try: - self._readers -= 1 - if self._readers == 0: - self._read_ready.notify_all() - finally: - self._read_ready.release() - - def acquire_write(self) -> None: - """Acquire a write lock.""" - _ = self._read_ready.acquire() - while self._readers > 0: - _ = self._read_ready.wait() - - def release_write(self) -> None: - """Release a write lock.""" - self._read_ready.release() - - @contextmanager - def read_lock(self): - """Return a context manager for read locking.""" - self.acquire_read() - try: - yield - finally: - self.release_read() - - @contextmanager - def write_lock(self): - """Return a context manager for write locking.""" - self.acquire_write() - try: - yield - finally: - self.release_write() - - -@final -class EventManager: - """ - Unified event manager that collects, buffers, and emits events. - - This class combines event collection with event emission, providing - thread-safe event management with support for notifying layers and - streaming events to external consumers. - """ - - def __init__(self) -> None: - """Initialize the event manager.""" - self._events: list[GraphEngineEvent] = [] - self._lock = ReadWriteLock() - self._layers: list[GraphEngineLayer] = [] - self._execution_complete = threading.Event() - - def set_layers(self, layers: list[GraphEngineLayer]) -> None: - """ - Set the layers to notify on event collection. 
- - Args: - layers: List of layers to notify - """ - self._layers = layers - - def notify_layers(self, event: GraphEngineEvent) -> None: - """Notify registered layers about an event without buffering it.""" - self._notify_layers(event) - - def collect(self, event: GraphEngineEvent) -> None: - """ - Thread-safe method to collect an event. - - Args: - event: The event to collect - """ - with self._lock.write_lock(): - self._events.append(event) - - # NOTE: `_notify_layers` is intentionally called outside the critical section - # to minimize lock contention and avoid blocking other readers or writers. - # - # The public `notify_layers` method also does not use a write lock, - # so protecting `_notify_layers` with a lock here is unnecessary. - self._notify_layers(event) - - def _get_new_events(self, start_index: int) -> list[GraphEngineEvent]: - """ - Get new events starting from a specific index. - - Args: - start_index: The index to start from - - Returns: - List of new events - """ - with self._lock.read_lock(): - return list(self._events[start_index:]) - - def _event_count(self) -> int: - """ - Get the current count of collected events. - - Returns: - Number of collected events - """ - with self._lock.read_lock(): - return len(self._events) - - def mark_complete(self) -> None: - """Mark execution as complete to stop the event emission generator.""" - self._execution_complete.set() - - def emit_events(self) -> Generator[GraphEngineEvent, None, None]: - """ - Generator that yields events as they're collected. 
- - Yields: - GraphEngineEvent instances as they're processed - """ - yielded_count = 0 - - while not self._execution_complete.is_set() or yielded_count < self._event_count(): - # Get new events since last yield - new_events = self._get_new_events(yielded_count) - - # Yield any new events - for event in new_events: - yield event - yielded_count += 1 - - # Small sleep to avoid busy waiting - if not self._execution_complete.is_set() and not new_events: - time.sleep(0.001) - - def _notify_layers(self, event: GraphEngineEvent) -> None: - """ - Notify all layers of an event. - - Layer exceptions are caught and logged to prevent disrupting collection. - - Args: - event: The event to send to layers - """ - for layer in self._layers: - try: - layer.on_event(event) - except Exception: - _logger.exception("Error in layer on_event, layer_type=%s", type(layer)) diff --git a/api/dify_graph/graph_engine/graph_engine.py b/api/dify_graph/graph_engine/graph_engine.py deleted file mode 100644 index ea98a46b06..0000000000 --- a/api/dify_graph/graph_engine/graph_engine.py +++ /dev/null @@ -1,384 +0,0 @@ -""" -QueueBasedGraphEngine - Main orchestrator for queue-based workflow execution. - -This engine uses a modular architecture with separated packages following -Domain-Driven Design principles for improved maintainability and testability. 
-""" - -from __future__ import annotations - -import logging -import queue -from collections.abc import Generator, Mapping -from typing import TYPE_CHECKING, cast, final - -from dify_graph.context import capture_current_context -from dify_graph.entities.workflow_start_reason import WorkflowStartReason -from dify_graph.enums import NodeExecutionType -from dify_graph.graph import Graph -from dify_graph.graph_events import ( - GraphEngineEvent, - GraphNodeEventBase, - GraphRunAbortedEvent, - GraphRunFailedEvent, - GraphRunPartialSucceededEvent, - GraphRunPausedEvent, - GraphRunStartedEvent, - GraphRunSucceededEvent, -) -from dify_graph.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper -from dify_graph.runtime.graph_runtime_state import ChildGraphEngineBuilderProtocol - -if TYPE_CHECKING: # pragma: no cover - used only for static analysis - from dify_graph.runtime.graph_runtime_state import GraphProtocol - -from .command_processing import ( - AbortCommandHandler, - CommandProcessor, - PauseCommandHandler, - UpdateVariablesCommandHandler, -) -from .config import GraphEngineConfig -from .entities.commands import AbortCommand, PauseCommand, UpdateVariablesCommand -from .error_handler import ErrorHandler -from .event_management import EventHandler, EventManager -from .graph_state_manager import GraphStateManager -from .graph_traversal import EdgeProcessor, SkipPropagator -from .layers.base import GraphEngineLayer -from .orchestration import Dispatcher, ExecutionCoordinator -from .protocols.command_channel import CommandChannel -from .worker_management import WorkerPool - -if TYPE_CHECKING: - from dify_graph.entities import GraphInitParams - from dify_graph.graph_engine.domain.graph_execution import GraphExecution - from dify_graph.graph_engine.response_coordinator import ResponseStreamCoordinator - -logger = logging.getLogger(__name__) - - -_DEFAULT_CONFIG = GraphEngineConfig() - - -@final -class GraphEngine: - """ - Queue-based graph execution engine. 
- - Uses a modular architecture that delegates responsibilities to specialized - subsystems, following Domain-Driven Design and SOLID principles. - """ - - def __init__( - self, - workflow_id: str, - graph: Graph, - graph_runtime_state: GraphRuntimeState, - command_channel: CommandChannel, - config: GraphEngineConfig = _DEFAULT_CONFIG, - child_engine_builder: ChildGraphEngineBuilderProtocol | None = None, - ) -> None: - """Initialize the graph engine with all subsystems and dependencies.""" - - # Bind runtime state to current workflow context - self._graph = graph - self._graph_runtime_state = graph_runtime_state - self._graph_runtime_state.configure(graph=cast("GraphProtocol", graph)) - self._command_channel = command_channel - self._config = config - self._child_engine_builder = child_engine_builder - if child_engine_builder is not None: - self._graph_runtime_state.bind_child_engine_builder(child_engine_builder) - - # Graph execution tracks the overall execution state - self._graph_execution = cast("GraphExecution", self._graph_runtime_state.graph_execution) - self._graph_execution.workflow_id = workflow_id - - # === Execution Queues === - self._ready_queue = self._graph_runtime_state.ready_queue - - # Queue for events generated during execution - self._event_queue: queue.Queue[GraphNodeEventBase] = queue.Queue() - - # === State Management === - # Unified state manager handles all node state transitions and queue operations - self._state_manager = GraphStateManager(self._graph, self._ready_queue) - - # === Response Coordination === - # Coordinates response streaming from response nodes - self._response_coordinator = cast("ResponseStreamCoordinator", self._graph_runtime_state.response_coordinator) - - # === Event Management === - # Event manager handles both collection and emission of events - self._event_manager = EventManager() - - # === Error Handling === - # Centralized error handler for graph execution errors - self._error_handler = ErrorHandler(self._graph, 
self._graph_execution) - - # === Graph Traversal Components === - # Propagates skip status through the graph when conditions aren't met - self._skip_propagator = SkipPropagator( - graph=self._graph, - state_manager=self._state_manager, - ) - - # Processes edges to determine next nodes after execution - # Also handles conditional branching and route selection - self._edge_processor = EdgeProcessor( - graph=self._graph, - state_manager=self._state_manager, - response_coordinator=self._response_coordinator, - skip_propagator=self._skip_propagator, - ) - - # === Command Processing === - # Processes external commands (e.g., abort requests) - self._command_processor = CommandProcessor( - command_channel=self._command_channel, - graph_execution=self._graph_execution, - ) - - # Register command handlers - abort_handler = AbortCommandHandler() - self._command_processor.register_handler(AbortCommand, abort_handler) - - pause_handler = PauseCommandHandler() - self._command_processor.register_handler(PauseCommand, pause_handler) - - update_variables_handler = UpdateVariablesCommandHandler(self._graph_runtime_state.variable_pool) - self._command_processor.register_handler(UpdateVariablesCommand, update_variables_handler) - - # === Extensibility === - # Layers allow plugins to extend engine functionality - self._layers: list[GraphEngineLayer] = [] - - # === Worker Pool Setup === - # Capture execution context for worker threads - execution_context = capture_current_context() - - # Create worker pool for parallel node execution - self._worker_pool = WorkerPool( - ready_queue=self._ready_queue, - event_queue=self._event_queue, - graph=self._graph, - layers=self._layers, - execution_context=execution_context, - config=self._config, - ) - - # === Orchestration === - # Coordinates the overall execution lifecycle - self._execution_coordinator = ExecutionCoordinator( - graph_execution=self._graph_execution, - state_manager=self._state_manager, - 
command_processor=self._command_processor, - worker_pool=self._worker_pool, - ) - - # === Event Handler Registry === - # Central registry for handling all node execution events - self._event_handler_registry = EventHandler( - graph=self._graph, - graph_runtime_state=self._graph_runtime_state, - graph_execution=self._graph_execution, - response_coordinator=self._response_coordinator, - event_collector=self._event_manager, - edge_processor=self._edge_processor, - state_manager=self._state_manager, - error_handler=self._error_handler, - ) - - # Dispatches events and manages execution flow - self._dispatcher = Dispatcher( - event_queue=self._event_queue, - event_handler=self._event_handler_registry, - execution_coordinator=self._execution_coordinator, - event_emitter=self._event_manager, - ) - - # === Validation === - # Ensure all nodes share the same GraphRuntimeState instance - self._validate_graph_state_consistency() - - def _validate_graph_state_consistency(self) -> None: - """Validate that all nodes share the same GraphRuntimeState.""" - expected_state_id = id(self._graph_runtime_state) - for node in self._graph.nodes.values(): - if id(node.graph_runtime_state) != expected_state_id: - raise ValueError(f"GraphRuntimeState consistency violation: Node '{node.id}' has a different instance") - - def _bind_layer_context( - self, - layer: GraphEngineLayer, - ) -> None: - layer.initialize(ReadOnlyGraphRuntimeStateWrapper(self._graph_runtime_state), self._command_channel) - - def layer(self, layer: GraphEngineLayer) -> GraphEngine: - """Add a layer for extending functionality.""" - self._layers.append(layer) - self._bind_layer_context(layer) - return self - - def create_child_engine( - self, - *, - workflow_id: str, - graph_init_params: GraphInitParams, - graph_runtime_state: GraphRuntimeState, - graph_config: dict[str, object] | Mapping[str, object], - root_node_id: str, - layers: list[GraphEngineLayer] | tuple[GraphEngineLayer, ...] 
= (), - ) -> GraphEngine: - return self._graph_runtime_state.create_child_engine( - workflow_id=workflow_id, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - graph_config=graph_config, - root_node_id=root_node_id, - layers=layers, - ) - - def run(self) -> Generator[GraphEngineEvent, None, None]: - """ - Execute the graph using the modular architecture. - - Returns: - Generator yielding GraphEngineEvent instances - """ - try: - # Initialize layers - self._initialize_layers() - - is_resume = self._graph_execution.started - if not is_resume: - self._graph_execution.start() - else: - self._graph_execution.paused = False - self._graph_execution.pause_reasons = [] - - start_event = GraphRunStartedEvent( - reason=WorkflowStartReason.RESUMPTION if is_resume else WorkflowStartReason.INITIAL, - ) - self._event_manager.notify_layers(start_event) - yield start_event - - # Start subsystems - self._start_execution(resume=is_resume) - - # Yield events as they occur - yield from self._event_manager.emit_events() - - # Handle completion - if self._graph_execution.is_paused: - pause_reasons = self._graph_execution.pause_reasons - assert pause_reasons, "pause_reasons should not be empty when execution is paused." 
- # Ensure we have a valid PauseReason for the event - paused_event = GraphRunPausedEvent( - reasons=pause_reasons, - outputs=self._graph_runtime_state.outputs, - ) - self._event_manager.notify_layers(paused_event) - yield paused_event - elif self._graph_execution.aborted: - abort_reason = "Workflow execution aborted by user command" - if self._graph_execution.error: - abort_reason = str(self._graph_execution.error) - aborted_event = GraphRunAbortedEvent( - reason=abort_reason, - outputs=self._graph_runtime_state.outputs, - ) - self._event_manager.notify_layers(aborted_event) - yield aborted_event - elif self._graph_execution.has_error: - if self._graph_execution.error: - raise self._graph_execution.error - else: - outputs = self._graph_runtime_state.outputs - exceptions_count = self._graph_execution.exceptions_count - if exceptions_count > 0: - partial_event = GraphRunPartialSucceededEvent( - exceptions_count=exceptions_count, - outputs=outputs, - ) - self._event_manager.notify_layers(partial_event) - yield partial_event - else: - succeeded_event = GraphRunSucceededEvent( - outputs=outputs, - ) - self._event_manager.notify_layers(succeeded_event) - yield succeeded_event - - except Exception as e: - failed_event = GraphRunFailedEvent( - error=str(e), - exceptions_count=self._graph_execution.exceptions_count, - ) - self._event_manager.notify_layers(failed_event) - yield failed_event - raise - - finally: - self._stop_execution() - - def _initialize_layers(self) -> None: - """Initialize layers with context.""" - self._event_manager.set_layers(self._layers) - for layer in self._layers: - try: - layer.on_graph_start() - except Exception: - logger.exception("Layer %s failed on_graph_start", layer.__class__.__name__) - - def _start_execution(self, *, resume: bool = False) -> None: - """Start execution subsystems.""" - paused_nodes: list[str] = [] - deferred_nodes: list[str] = [] - if resume: - paused_nodes = self._graph_runtime_state.consume_paused_nodes() - 
deferred_nodes = self._graph_runtime_state.consume_deferred_nodes() - - # Start worker pool (it calculates initial workers internally) - self._worker_pool.start() - - # Register response nodes - for node in self._graph.nodes.values(): - if node.execution_type == NodeExecutionType.RESPONSE: - self._response_coordinator.register(node.id) - - if not resume: - # Enqueue root node - root_node = self._graph.root_node - self._state_manager.enqueue_node(root_node.id) - self._state_manager.start_execution(root_node.id) - else: - seen_nodes: set[str] = set() - for node_id in paused_nodes + deferred_nodes: - if node_id in seen_nodes: - continue - seen_nodes.add(node_id) - self._state_manager.enqueue_node(node_id) - self._state_manager.start_execution(node_id) - - # Start dispatcher - self._dispatcher.start() - - def _stop_execution(self) -> None: - """Stop execution subsystems.""" - self._dispatcher.stop() - self._worker_pool.stop() - # Don't mark complete here as the dispatcher already does it - - # Notify layers - for layer in self._layers: - try: - layer.on_graph_end(self._graph_execution.error) - except Exception: - logger.exception("Layer %s failed on_graph_end", layer.__class__.__name__) - - # Public property accessors for attributes that need external access - @property - def graph_runtime_state(self) -> GraphRuntimeState: - """Get the graph runtime state.""" - return self._graph_runtime_state diff --git a/api/dify_graph/graph_engine/graph_state_manager.py b/api/dify_graph/graph_engine/graph_state_manager.py deleted file mode 100644 index 922a968435..0000000000 --- a/api/dify_graph/graph_engine/graph_state_manager.py +++ /dev/null @@ -1,290 +0,0 @@ -""" -Graph state manager that combines node, edge, and execution tracking. 
-""" - -import threading -from collections.abc import Sequence -from typing import TypedDict, final - -from dify_graph.enums import NodeState -from dify_graph.graph import Edge, Graph - -from .ready_queue import ReadyQueue - - -class EdgeStateAnalysis(TypedDict): - """Analysis result for edge states.""" - - has_unknown: bool - has_taken: bool - all_skipped: bool - - -@final -class GraphStateManager: - def __init__(self, graph: Graph, ready_queue: ReadyQueue) -> None: - """ - Initialize the state manager. - - Args: - graph: The workflow graph - ready_queue: Queue for nodes ready to execute - """ - self._graph = graph - self._ready_queue = ready_queue - self._lock = threading.RLock() - - # Execution tracking state - self._executing_nodes: set[str] = set() - - # ============= Node State Operations ============= - - def enqueue_node(self, node_id: str) -> None: - """ - Mark a node as TAKEN and add it to the ready queue. - - This combines the state transition and enqueueing operations - that always occur together when preparing a node for execution. - - Args: - node_id: The ID of the node to enqueue - """ - with self._lock: - self._graph.nodes[node_id].state = NodeState.TAKEN - self._ready_queue.put(node_id) - - def mark_node_skipped(self, node_id: str) -> None: - """ - Mark a node as SKIPPED. - - Args: - node_id: The ID of the node to skip - """ - with self._lock: - self._graph.nodes[node_id].state = NodeState.SKIPPED - - def is_node_ready(self, node_id: str) -> bool: - """ - Check if a node is ready to be executed. - - A node is ready when all its incoming edges from taken branches - have been satisfied. 
- - Args: - node_id: The ID of the node to check - - Returns: - True if the node is ready for execution - """ - with self._lock: - # Get all incoming edges to this node - incoming_edges = self._graph.get_incoming_edges(node_id) - - # If no incoming edges, node is always ready - if not incoming_edges: - return True - - # If any edge is UNKNOWN, node is not ready - if any(edge.state == NodeState.UNKNOWN for edge in incoming_edges): - return False - - # Node is ready if at least one edge is TAKEN - return any(edge.state == NodeState.TAKEN for edge in incoming_edges) - - def get_node_state(self, node_id: str) -> NodeState: - """ - Get the current state of a node. - - Args: - node_id: The ID of the node - - Returns: - The current node state - """ - with self._lock: - return self._graph.nodes[node_id].state - - # ============= Edge State Operations ============= - - def mark_edge_taken(self, edge_id: str) -> None: - """ - Mark an edge as TAKEN. - - Args: - edge_id: The ID of the edge to mark - """ - with self._lock: - self._graph.edges[edge_id].state = NodeState.TAKEN - - def mark_edge_skipped(self, edge_id: str) -> None: - """ - Mark an edge as SKIPPED. - - Args: - edge_id: The ID of the edge to mark - """ - with self._lock: - self._graph.edges[edge_id].state = NodeState.SKIPPED - - def analyze_edge_states(self, edges: list[Edge]) -> EdgeStateAnalysis: - """ - Analyze the states of edges and return summary flags. - - Args: - edges: List of edges to analyze - - Returns: - Analysis result with state flags - """ - with self._lock: - states = {edge.state for edge in edges} - - return EdgeStateAnalysis( - has_unknown=NodeState.UNKNOWN in states, - has_taken=NodeState.TAKEN in states, - all_skipped=states == {NodeState.SKIPPED} if states else True, - ) - - def get_edge_state(self, edge_id: str) -> NodeState: - """ - Get the current state of an edge. 
- - Args: - edge_id: The ID of the edge - - Returns: - The current edge state - """ - with self._lock: - return self._graph.edges[edge_id].state - - def categorize_branch_edges(self, node_id: str, selected_handle: str) -> tuple[Sequence[Edge], Sequence[Edge]]: - """ - Categorize branch edges into selected and unselected. - - Args: - node_id: The ID of the branch node - selected_handle: The handle of the selected edge - - Returns: - A tuple of (selected_edges, unselected_edges) - """ - with self._lock: - outgoing_edges = self._graph.get_outgoing_edges(node_id) - selected_edges: list[Edge] = [] - unselected_edges: list[Edge] = [] - - for edge in outgoing_edges: - if edge.source_handle == selected_handle: - selected_edges.append(edge) - else: - unselected_edges.append(edge) - - return selected_edges, unselected_edges - - # ============= Execution Tracking Operations ============= - - def start_execution(self, node_id: str) -> None: - """ - Mark a node as executing. - - Args: - node_id: The ID of the node starting execution - """ - with self._lock: - self._executing_nodes.add(node_id) - - def finish_execution(self, node_id: str) -> None: - """ - Mark a node as no longer executing. - - Args: - node_id: The ID of the node finishing execution - """ - with self._lock: - self._executing_nodes.discard(node_id) - - def is_executing(self, node_id: str) -> bool: - """ - Check if a node is currently executing. - - Args: - node_id: The ID of the node to check - - Returns: - True if the node is executing - """ - with self._lock: - return node_id in self._executing_nodes - - def get_executing_count(self) -> int: - """ - Get the count of currently executing nodes. - - Returns: - Number of executing nodes - """ - # This count is a best-effort snapshot and can change concurrently. - # Only use it for pause-drain checks where scheduling is already frozen. 
- with self._lock: - return len(self._executing_nodes) - - def get_executing_nodes(self) -> set[str]: - """ - Get a copy of the set of executing node IDs. - - Returns: - Set of node IDs currently executing - """ - with self._lock: - return self._executing_nodes.copy() - - def clear_executing(self) -> None: - """Clear all executing nodes.""" - with self._lock: - self._executing_nodes.clear() - - # ============= Composite Operations ============= - - def is_execution_complete(self) -> bool: - """ - Check if graph execution is complete. - - Execution is complete when: - - Ready queue is empty - - No nodes are executing - - Returns: - True if execution is complete - """ - with self._lock: - return self._ready_queue.empty() and len(self._executing_nodes) == 0 - - def get_queue_depth(self) -> int: - """ - Get the current depth of the ready queue. - - Returns: - Number of nodes in the ready queue - """ - return self._ready_queue.qsize() - - def get_execution_stats(self) -> dict[str, int]: - """ - Get execution statistics. - - Returns: - Dictionary with execution statistics - """ - with self._lock: - taken_nodes = sum(1 for node in self._graph.nodes.values() if node.state == NodeState.TAKEN) - skipped_nodes = sum(1 for node in self._graph.nodes.values() if node.state == NodeState.SKIPPED) - unknown_nodes = sum(1 for node in self._graph.nodes.values() if node.state == NodeState.UNKNOWN) - - return { - "queue_depth": self._ready_queue.qsize(), - "executing": len(self._executing_nodes), - "taken_nodes": taken_nodes, - "skipped_nodes": skipped_nodes, - "unknown_nodes": unknown_nodes, - } diff --git a/api/dify_graph/graph_engine/graph_traversal/__init__.py b/api/dify_graph/graph_engine/graph_traversal/__init__.py deleted file mode 100644 index d629140d06..0000000000 --- a/api/dify_graph/graph_engine/graph_traversal/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -""" -Graph traversal subsystem for graph engine. 
- -This package handles graph navigation, edge processing, -and skip propagation logic. -""" - -from .edge_processor import EdgeProcessor -from .skip_propagator import SkipPropagator - -__all__ = [ - "EdgeProcessor", - "SkipPropagator", -] diff --git a/api/dify_graph/graph_engine/graph_traversal/edge_processor.py b/api/dify_graph/graph_engine/graph_traversal/edge_processor.py deleted file mode 100644 index c4625a8ff7..0000000000 --- a/api/dify_graph/graph_engine/graph_traversal/edge_processor.py +++ /dev/null @@ -1,201 +0,0 @@ -""" -Edge processing logic for graph traversal. -""" - -from collections.abc import Sequence -from typing import TYPE_CHECKING, final - -from dify_graph.enums import NodeExecutionType -from dify_graph.graph import Edge, Graph -from dify_graph.graph_events import NodeRunStreamChunkEvent - -from ..graph_state_manager import GraphStateManager -from ..response_coordinator import ResponseStreamCoordinator - -if TYPE_CHECKING: - from .skip_propagator import SkipPropagator - - -@final -class EdgeProcessor: - """ - Processes edges during graph execution. - - This handles marking edges as taken or skipped, notifying - the response coordinator, triggering downstream node execution, - and managing branch node logic. - """ - - def __init__( - self, - graph: Graph, - state_manager: GraphStateManager, - response_coordinator: ResponseStreamCoordinator, - skip_propagator: "SkipPropagator", - ) -> None: - """ - Initialize the edge processor. 
- - Args: - graph: The workflow graph - state_manager: Unified state manager - response_coordinator: Response stream coordinator - skip_propagator: Propagator for skip states - """ - self._graph = graph - self._state_manager = state_manager - self._response_coordinator = response_coordinator - self._skip_propagator = skip_propagator - - def process_node_success( - self, node_id: str, selected_handle: str | None = None - ) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]: - """ - Process edges after a node succeeds. - - Args: - node_id: The ID of the succeeded node - selected_handle: For branch nodes, the selected edge handle - - Returns: - Tuple of (list of downstream node IDs that are now ready, list of streaming events) - """ - node = self._graph.nodes[node_id] - - if node.execution_type == NodeExecutionType.BRANCH: - return self._process_branch_node_edges(node_id, selected_handle) - else: - return self._process_non_branch_node_edges(node_id) - - def _process_non_branch_node_edges(self, node_id: str) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]: - """ - Process edges for non-branch nodes (mark all as TAKEN). - - Args: - node_id: The ID of the succeeded node - - Returns: - Tuple of (list of downstream nodes ready for execution, list of streaming events) - """ - ready_nodes: list[str] = [] - all_streaming_events: list[NodeRunStreamChunkEvent] = [] - outgoing_edges = self._graph.get_outgoing_edges(node_id) - - for edge in outgoing_edges: - nodes, events = self._process_taken_edge(edge) - ready_nodes.extend(nodes) - all_streaming_events.extend(events) - - return ready_nodes, all_streaming_events - - def _process_branch_node_edges( - self, node_id: str, selected_handle: str | None - ) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]: - """ - Process edges for branch nodes. 
- - Args: - node_id: The ID of the branch node - selected_handle: The handle of the selected edge - - Returns: - Tuple of (list of downstream nodes ready for execution, list of streaming events) - - Raises: - ValueError: If no edge was selected - """ - if not selected_handle: - raise ValueError(f"Branch node {node_id} did not select any edge") - - ready_nodes: list[str] = [] - all_streaming_events: list[NodeRunStreamChunkEvent] = [] - - # Categorize edges - selected_edges, unselected_edges = self._state_manager.categorize_branch_edges(node_id, selected_handle) - - # Process unselected edges first (mark as skipped) - for edge in unselected_edges: - self._process_skipped_edge(edge) - - # Process selected edges - for edge in selected_edges: - nodes, events = self._process_taken_edge(edge) - ready_nodes.extend(nodes) - all_streaming_events.extend(events) - - return ready_nodes, all_streaming_events - - def _process_taken_edge(self, edge: Edge) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]: - """ - Mark edge as taken and check downstream node. - - Args: - edge: The edge to process - - Returns: - Tuple of (list containing downstream node ID if it's ready, list of streaming events) - """ - # Mark edge as taken - self._state_manager.mark_edge_taken(edge.id) - - # Notify response coordinator and get streaming events - streaming_events = self._response_coordinator.on_edge_taken(edge.id) - - # Check if downstream node is ready - ready_nodes: list[str] = [] - if self._state_manager.is_node_ready(edge.head): - ready_nodes.append(edge.head) - - return ready_nodes, streaming_events - - def _process_skipped_edge(self, edge: Edge) -> None: - """ - Mark edge as skipped. - - Args: - edge: The edge to skip - """ - self._state_manager.mark_edge_skipped(edge.id) - - def handle_branch_completion( - self, node_id: str, selected_handle: str | None - ) -> tuple[Sequence[str], Sequence[NodeRunStreamChunkEvent]]: - """ - Handle completion of a branch node. 
- - Args: - node_id: The ID of the branch node - selected_handle: The handle of the selected branch - - Returns: - Tuple of (list of downstream nodes ready for execution, list of streaming events) - - Raises: - ValueError: If no branch was selected - """ - if not selected_handle: - raise ValueError(f"Branch node {node_id} completed without selecting a branch") - - # Categorize edges into selected and unselected - _, unselected_edges = self._state_manager.categorize_branch_edges(node_id, selected_handle) - - # Skip all unselected paths - self._skip_propagator.skip_branch_paths(unselected_edges) - - # Process selected edges and get ready nodes and streaming events - return self.process_node_success(node_id, selected_handle) - - def validate_branch_selection(self, node_id: str, selected_handle: str) -> bool: - """ - Validate that a branch selection is valid. - - Args: - node_id: The ID of the branch node - selected_handle: The handle to validate - - Returns: - True if the selection is valid - """ - outgoing_edges = self._graph.get_outgoing_edges(node_id) - valid_handles = {edge.source_handle for edge in outgoing_edges} - return selected_handle in valid_handles diff --git a/api/dify_graph/graph_engine/graph_traversal/skip_propagator.py b/api/dify_graph/graph_engine/graph_traversal/skip_propagator.py deleted file mode 100644 index 76445bccd2..0000000000 --- a/api/dify_graph/graph_engine/graph_traversal/skip_propagator.py +++ /dev/null @@ -1,96 +0,0 @@ -""" -Skip state propagation through the graph. -""" - -from collections.abc import Sequence -from typing import final - -from dify_graph.graph import Edge, Graph - -from ..graph_state_manager import GraphStateManager - - -@final -class SkipPropagator: - """ - Propagates skip states through the graph. - - When a node is skipped, this ensures all downstream nodes - that depend solely on it are also skipped. 
- """ - - def __init__( - self, - graph: Graph, - state_manager: GraphStateManager, - ) -> None: - """ - Initialize the skip propagator. - - Args: - graph: The workflow graph - state_manager: Unified state manager - """ - self._graph = graph - self._state_manager = state_manager - - def propagate_skip_from_edge(self, edge_id: str) -> None: - """ - Recursively propagate skip state from a skipped edge. - - Rules: - - If a node has any UNKNOWN incoming edges, stop processing - - If all incoming edges are SKIPPED, skip the node and its edges - - If any incoming edge is TAKEN, the node may still execute - - Args: - edge_id: The ID of the skipped edge to start from - """ - downstream_node_id = self._graph.edges[edge_id].head - incoming_edges = self._graph.get_incoming_edges(downstream_node_id) - - # Analyze edge states - edge_states = self._state_manager.analyze_edge_states(incoming_edges) - - # Stop if there are unknown edges (not yet processed) - if edge_states["has_unknown"]: - return - - # If any edge is taken, node may still execute - if edge_states["has_taken"]: - # Enqueue node - self._state_manager.enqueue_node(downstream_node_id) - self._state_manager.start_execution(downstream_node_id) - return - - # All edges are skipped, propagate skip to this node - if edge_states["all_skipped"]: - self._propagate_skip_to_node(downstream_node_id) - - def _propagate_skip_to_node(self, node_id: str) -> None: - """ - Mark a node and all its outgoing edges as skipped. - - Args: - node_id: The ID of the node to skip - """ - # Mark node as skipped - self._state_manager.mark_node_skipped(node_id) - - # Mark all outgoing edges as skipped and propagate - outgoing_edges = self._graph.get_outgoing_edges(node_id) - for edge in outgoing_edges: - self._state_manager.mark_edge_skipped(edge.id) - # Recursively propagate skip - self.propagate_skip_from_edge(edge.id) - - def skip_branch_paths(self, unselected_edges: Sequence[Edge]) -> None: - """ - Skip all paths from unselected branch edges. 
- - Args: - unselected_edges: List of edges not taken by the branch - """ - for edge in unselected_edges: - self._state_manager.mark_edge_skipped(edge.id) - self.propagate_skip_from_edge(edge.id) diff --git a/api/dify_graph/graph_engine/layers/README.md b/api/dify_graph/graph_engine/layers/README.md deleted file mode 100644 index b0f295037c..0000000000 --- a/api/dify_graph/graph_engine/layers/README.md +++ /dev/null @@ -1,55 +0,0 @@ -# Layers - -Pluggable middleware for engine extensions. - -## Components - -### Layer (base) - -Abstract base class for layers. - -- `initialize()` - Receive runtime context (runtime state is bound here and always available to hooks) -- `on_graph_start()` - Execution start hook -- `on_event()` - Process all events -- `on_graph_end()` - Execution end hook - -### DebugLoggingLayer - -Comprehensive execution logging. - -- Configurable detail levels -- Tracks execution statistics -- Truncates long values - -## Usage - -```python -debug_layer = DebugLoggingLayer( - level="INFO", - include_outputs=True -) - -engine = GraphEngine(graph) -engine.layer(debug_layer) -engine.run() -``` - -`engine.layer()` binds the read-only runtime state before execution, so -`graph_runtime_state` is always available inside layer hooks. - -## Custom Layers - -```python -class MetricsLayer(Layer): - def on_event(self, event): - if isinstance(event, NodeRunSucceededEvent): - self.metrics[event.node_id] = event.elapsed_time -``` - -## Configuration - -**DebugLoggingLayer Options:** - -- `level` - Log level (INFO, DEBUG, ERROR) -- `include_inputs/outputs` - Log data values -- `max_value_length` - Truncate long values diff --git a/api/dify_graph/graph_engine/layers/__init__.py b/api/dify_graph/graph_engine/layers/__init__.py deleted file mode 100644 index 0a29a52993..0000000000 --- a/api/dify_graph/graph_engine/layers/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -""" -Layer system for GraphEngine extensibility. 
- -This module provides the layer infrastructure for extending GraphEngine functionality -with middleware-like components that can observe events and interact with execution. -""" - -from .base import GraphEngineLayer -from .debug_logging import DebugLoggingLayer -from .execution_limits import ExecutionLimitsLayer - -__all__ = [ - "DebugLoggingLayer", - "ExecutionLimitsLayer", - "GraphEngineLayer", -] diff --git a/api/dify_graph/graph_engine/layers/base.py b/api/dify_graph/graph_engine/layers/base.py deleted file mode 100644 index 890336c1ca..0000000000 --- a/api/dify_graph/graph_engine/layers/base.py +++ /dev/null @@ -1,128 +0,0 @@ -""" -Base layer class for GraphEngine extensions. - -This module provides the abstract base class for implementing layers that can -intercept and respond to GraphEngine events. -""" - -from abc import ABC, abstractmethod - -from dify_graph.graph_engine.protocols.command_channel import CommandChannel -from dify_graph.graph_events import GraphEngineEvent, GraphNodeEventBase -from dify_graph.nodes.base.node import Node -from dify_graph.runtime import ReadOnlyGraphRuntimeState - - -class GraphEngineLayerNotInitializedError(Exception): - """Raised when a layer's runtime state is accessed before initialization.""" - - def __init__(self, layer_name: str | None = None) -> None: - name = layer_name or "GraphEngineLayer" - super().__init__(f"{name} runtime state is not initialized. Bind the layer to a GraphEngine before access.") - - -class GraphEngineLayer(ABC): - """ - Abstract base class for GraphEngine layers. - - Layers are middleware-like components that can: - - Observe all events emitted by the GraphEngine - - Access the graph runtime state - - Send commands to control execution - - Subclasses should override the constructor to accept configuration parameters, - then implement the three lifecycle methods. - """ - - def __init__(self) -> None: - """Initialize the layer. 
Subclasses can override with custom parameters.""" - self._graph_runtime_state: ReadOnlyGraphRuntimeState | None = None - self.command_channel: CommandChannel | None = None - - @property - def graph_runtime_state(self) -> ReadOnlyGraphRuntimeState: - if self._graph_runtime_state is None: - raise GraphEngineLayerNotInitializedError(type(self).__name__) - return self._graph_runtime_state - - def initialize(self, graph_runtime_state: ReadOnlyGraphRuntimeState, command_channel: CommandChannel) -> None: - """ - Initialize the layer with engine dependencies. - - Called by GraphEngine to inject the read-only runtime state and command channel. - This is invoked when the layer is registered with a `GraphEngine` instance. - Implementations should be idempotent. - Args: - graph_runtime_state: Read-only view of the runtime state - command_channel: Channel for sending commands to the engine - """ - self._graph_runtime_state = graph_runtime_state - self.command_channel = command_channel - - @abstractmethod - def on_graph_start(self) -> None: - """ - Called when graph execution starts. - - This is called after the engine has been initialized but before any nodes - are executed. Layers can use this to set up resources or log start information. - """ - pass - - @abstractmethod - def on_event(self, event: GraphEngineEvent) -> None: - """ - Called for every event emitted by the engine. - - This method receives all events generated during graph execution, including: - - Graph lifecycle events (start, success, failure) - - Node execution events (start, success, failure, retry) - - Stream events for response nodes - - Container events (iteration, loop) - - Args: - event: The event emitted by the engine - """ - pass - - @abstractmethod - def on_graph_end(self, error: Exception | None) -> None: - """ - Called when graph execution ends. - - This is called after all nodes have been executed or when execution is - aborted. Layers can use this to clean up resources or log final state. 
- - Args: - error: The exception that caused execution to fail, or None if successful - """ - pass - - def on_node_run_start(self, node: Node) -> None: - """ - Called immediately before a node begins execution. - - Layers can override to inject behavior (e.g., start spans) prior to node execution. - The node's execution ID is available via `node._node_execution_id` and will be - consistent with all events emitted by this node execution. - - Args: - node: The node instance about to be executed - """ - return - - def on_node_run_end( - self, node: Node, error: Exception | None, result_event: GraphNodeEventBase | None = None - ) -> None: - """ - Called after a node finishes execution. - - The node's execution ID is available via `node._node_execution_id` and matches - the `id` field in all events emitted by this node execution. - - Args: - node: The node instance that just finished execution - error: Exception instance if the node failed, otherwise None - result_event: The final result event from node execution (succeeded/failed/paused), if any - """ - return diff --git a/api/dify_graph/graph_engine/layers/debug_logging.py b/api/dify_graph/graph_engine/layers/debug_logging.py deleted file mode 100644 index 1af2e2db9e..0000000000 --- a/api/dify_graph/graph_engine/layers/debug_logging.py +++ /dev/null @@ -1,247 +0,0 @@ -""" -Debug logging layer for GraphEngine. - -This module provides a layer that logs all events and state changes during -graph execution for debugging purposes. 
-""" - -import logging -from collections.abc import Mapping -from typing import Any, final - -from typing_extensions import override - -from dify_graph.graph_events import ( - GraphEngineEvent, - GraphRunAbortedEvent, - GraphRunFailedEvent, - GraphRunPartialSucceededEvent, - GraphRunStartedEvent, - GraphRunSucceededEvent, - NodeRunExceptionEvent, - NodeRunFailedEvent, - NodeRunIterationFailedEvent, - NodeRunIterationNextEvent, - NodeRunIterationStartedEvent, - NodeRunIterationSucceededEvent, - NodeRunLoopFailedEvent, - NodeRunLoopNextEvent, - NodeRunLoopStartedEvent, - NodeRunLoopSucceededEvent, - NodeRunRetryEvent, - NodeRunStartedEvent, - NodeRunStreamChunkEvent, - NodeRunSucceededEvent, -) - -from .base import GraphEngineLayer - - -@final -class DebugLoggingLayer(GraphEngineLayer): - """ - A layer that provides comprehensive logging of GraphEngine execution. - - This layer logs all events with configurable detail levels, helping developers - debug workflow execution and understand the flow of events. - """ - - def __init__( - self, - level: str = "INFO", - include_inputs: bool = False, - include_outputs: bool = True, - include_process_data: bool = False, - logger_name: str = "GraphEngine.Debug", - max_value_length: int = 500, - ) -> None: - """ - Initialize the debug logging layer. 
- - Args: - level: Logging level (DEBUG, INFO, WARNING, ERROR) - include_inputs: Whether to log node input values - include_outputs: Whether to log node output values - include_process_data: Whether to log node process data - logger_name: Name of the logger to use - max_value_length: Maximum length of logged values (truncated if longer) - """ - super().__init__() - self.level = level - self.include_inputs = include_inputs - self.include_outputs = include_outputs - self.include_process_data = include_process_data - self.max_value_length = max_value_length - - # Set up logger - self.logger = logging.getLogger(logger_name) - log_level = getattr(logging, level.upper(), logging.INFO) - self.logger.setLevel(log_level) - - # Track execution stats - self.node_count = 0 - self.success_count = 0 - self.failure_count = 0 - self.retry_count = 0 - - def _truncate_value(self, value: Any) -> str: - """Truncate long values for logging.""" - str_value = str(value) - if len(str_value) > self.max_value_length: - return str_value[: self.max_value_length] + "... 
(truncated)" - return str_value - - def _format_dict(self, data: dict[str, Any] | Mapping[str, Any]) -> str: - """Format a dictionary or mapping for logging with truncation.""" - if not data: - return "{}" - - formatted_items: list[str] = [] - for key, value in data.items(): - formatted_value = self._truncate_value(value) - formatted_items.append(f" {key}: {formatted_value}") - - return "{\n" + ",\n".join(formatted_items) + "\n}" - - @override - def on_graph_start(self) -> None: - """Log graph execution start.""" - self.logger.info("=" * 80) - self.logger.info("🚀 GRAPH EXECUTION STARTED") - self.logger.info("=" * 80) - # Log initial state - self.logger.info("Initial State:") - - @override - def on_event(self, event: GraphEngineEvent) -> None: - """Log individual events based on their type.""" - event_class = event.__class__.__name__ - - # Graph-level events - if isinstance(event, GraphRunStartedEvent): - self.logger.debug("Graph run started event") - - elif isinstance(event, GraphRunSucceededEvent): - self.logger.info("✅ Graph run succeeded") - if self.include_outputs and event.outputs: - self.logger.info(" Final outputs: %s", self._format_dict(event.outputs)) - - elif isinstance(event, GraphRunPartialSucceededEvent): - self.logger.warning("⚠️ Graph run partially succeeded") - if event.exceptions_count > 0: - self.logger.warning(" Total exceptions: %s", event.exceptions_count) - if self.include_outputs and event.outputs: - self.logger.info(" Final outputs: %s", self._format_dict(event.outputs)) - - elif isinstance(event, GraphRunFailedEvent): - self.logger.error("❌ Graph run failed: %s", event.error) - if event.exceptions_count > 0: - self.logger.error(" Total exceptions: %s", event.exceptions_count) - - elif isinstance(event, GraphRunAbortedEvent): - self.logger.warning("⚠️ Graph run aborted: %s", event.reason) - if event.outputs: - self.logger.info(" Partial outputs: %s", self._format_dict(event.outputs)) - - # Node-level events - # Retry before Started because 
Retry subclasses Started; - elif isinstance(event, NodeRunRetryEvent): - self.retry_count += 1 - self.logger.warning("🔄 Node retry: %s (attempt %s)", event.node_id, event.retry_index) - self.logger.warning(" Previous error: %s", event.error) - - elif isinstance(event, NodeRunStartedEvent): - self.node_count += 1 - self.logger.info('▶️ Node started: %s - "%s" (type: %s)', event.node_id, event.node_title, event.node_type) - - if self.include_inputs and event.node_run_result.inputs: - self.logger.debug(" Inputs: %s", self._format_dict(event.node_run_result.inputs)) - - elif isinstance(event, NodeRunSucceededEvent): - self.success_count += 1 - self.logger.info("✅ Node succeeded: %s", event.node_id) - - if self.include_outputs and event.node_run_result.outputs: - self.logger.debug(" Outputs: %s", self._format_dict(event.node_run_result.outputs)) - - if self.include_process_data and event.node_run_result.process_data: - self.logger.debug(" Process data: %s", self._format_dict(event.node_run_result.process_data)) - - elif isinstance(event, NodeRunFailedEvent): - self.failure_count += 1 - self.logger.error("❌ Node failed: %s", event.node_id) - self.logger.error(" Error: %s", event.error) - - if event.node_run_result.error: - self.logger.error(" Details: %s", event.node_run_result.error) - - elif isinstance(event, NodeRunExceptionEvent): - self.logger.warning("⚠️ Node exception handled: %s", event.node_id) - self.logger.warning(" Error: %s", event.error) - - elif isinstance(event, NodeRunStreamChunkEvent): - # Log stream chunks at debug level to avoid spam - final_indicator = " (FINAL)" if event.is_final else "" - self.logger.debug( - "📝 Stream chunk from %s%s: %s", event.node_id, final_indicator, self._truncate_value(event.chunk) - ) - - # Iteration events - elif isinstance(event, NodeRunIterationStartedEvent): - self.logger.info("🔁 Iteration started: %s", event.node_id) - - elif isinstance(event, NodeRunIterationNextEvent): - self.logger.debug(" Iteration next: %s (index: 
%s)", event.node_id, event.index) - - elif isinstance(event, NodeRunIterationSucceededEvent): - self.logger.info("✅ Iteration succeeded: %s", event.node_id) - if self.include_outputs and event.outputs: - self.logger.debug(" Outputs: %s", self._format_dict(event.outputs)) - - elif isinstance(event, NodeRunIterationFailedEvent): - self.logger.error("❌ Iteration failed: %s", event.node_id) - self.logger.error(" Error: %s", event.error) - - # Loop events - elif isinstance(event, NodeRunLoopStartedEvent): - self.logger.info("🔄 Loop started: %s", event.node_id) - - elif isinstance(event, NodeRunLoopNextEvent): - self.logger.debug(" Loop iteration: %s (index: %s)", event.node_id, event.index) - - elif isinstance(event, NodeRunLoopSucceededEvent): - self.logger.info("✅ Loop succeeded: %s", event.node_id) - if self.include_outputs and event.outputs: - self.logger.debug(" Outputs: %s", self._format_dict(event.outputs)) - - elif isinstance(event, NodeRunLoopFailedEvent): - self.logger.error("❌ Loop failed: %s", event.node_id) - self.logger.error(" Error: %s", event.error) - - else: - # Log unknown events at debug level - self.logger.debug("Event: %s", event_class) - - @override - def on_graph_end(self, error: Exception | None) -> None: - """Log graph execution end with summary statistics.""" - self.logger.info("=" * 80) - - if error: - self.logger.error("🔴 GRAPH EXECUTION FAILED") - self.logger.error(" Error: %s", error) - else: - self.logger.info("🎉 GRAPH EXECUTION COMPLETED SUCCESSFULLY") - - # Log execution statistics - self.logger.info("Execution Statistics:") - self.logger.info(" Total nodes executed: %s", self.node_count) - self.logger.info(" Successful nodes: %s", self.success_count) - self.logger.info(" Failed nodes: %s", self.failure_count) - self.logger.info(" Node retries: %s", self.retry_count) - - # Log final state if available - if self.include_outputs and self.graph_runtime_state.outputs: - self.logger.info("Final outputs: %s", 
self._format_dict(self.graph_runtime_state.outputs)) - - self.logger.info("=" * 80) diff --git a/api/dify_graph/graph_engine/layers/execution_limits.py b/api/dify_graph/graph_engine/layers/execution_limits.py deleted file mode 100644 index 48ba5608d9..0000000000 --- a/api/dify_graph/graph_engine/layers/execution_limits.py +++ /dev/null @@ -1,150 +0,0 @@ -""" -Execution limits layer for GraphEngine. - -This layer monitors workflow execution to enforce limits on: -- Maximum execution steps -- Maximum execution time - -When limits are exceeded, the layer automatically aborts execution. -""" - -import logging -import time -from enum import StrEnum -from typing import final - -from typing_extensions import override - -from dify_graph.graph_engine.entities.commands import AbortCommand, CommandType -from dify_graph.graph_engine.layers import GraphEngineLayer -from dify_graph.graph_events import ( - GraphEngineEvent, - NodeRunStartedEvent, -) -from dify_graph.graph_events.node import NodeRunFailedEvent, NodeRunSucceededEvent - - -class LimitType(StrEnum): - """Types of execution limits that can be exceeded.""" - - STEP_LIMIT = "step_limit" - TIME_LIMIT = "time_limit" - - -@final -class ExecutionLimitsLayer(GraphEngineLayer): - """ - Layer that enforces execution limits for workflows. - - Monitors: - - Step count: Tracks number of node executions - - Time limit: Monitors total execution time - - Automatically aborts execution when limits are exceeded. - """ - - def __init__(self, max_steps: int, max_time: int) -> None: - """ - Initialize the execution limits layer. 
- - Args: - max_steps: Maximum number of execution steps allowed - max_time: Maximum execution time in seconds allowed - """ - super().__init__() - self.max_steps = max_steps - self.max_time = max_time - - # Runtime tracking - self.start_time: float | None = None - self.step_count = 0 - self.logger = logging.getLogger(__name__) - - # State tracking - self._execution_started = False - self._execution_ended = False - self._abort_sent = False # Track if abort command has been sent - - @override - def on_graph_start(self) -> None: - """Called when graph execution starts.""" - self.start_time = time.time() - self.step_count = 0 - self._execution_started = True - self._execution_ended = False - self._abort_sent = False - - self.logger.debug("Execution limits monitoring started") - - @override - def on_event(self, event: GraphEngineEvent) -> None: - """ - Called for every event emitted by the engine. - - Monitors execution progress and enforces limits. - """ - if not self._execution_started or self._execution_ended or self._abort_sent: - return - - # Track step count for node execution events - if isinstance(event, NodeRunStartedEvent): - self.step_count += 1 - self.logger.debug("Step %d started: %s", self.step_count, event.node_id) - - # Check step limit when node execution completes - if isinstance(event, NodeRunSucceededEvent | NodeRunFailedEvent): - if self._reached_step_limitation(): - self._send_abort_command(LimitType.STEP_LIMIT) - - if self._reached_time_limitation(): - self._send_abort_command(LimitType.TIME_LIMIT) - - @override - def on_graph_end(self, error: Exception | None) -> None: - """Called when graph execution ends.""" - if self._execution_started and not self._execution_ended: - self._execution_ended = True - - if self.start_time: - total_time = time.time() - self.start_time - self.logger.debug("Execution completed: %d steps in %.2f seconds", self.step_count, total_time) - - def _reached_step_limitation(self) -> bool: - """Check if step count limit has 
been exceeded.""" - return self.step_count > self.max_steps - - def _reached_time_limitation(self) -> bool: - """Check if time limit has been exceeded.""" - return self.start_time is not None and (time.time() - self.start_time) > self.max_time - - def _send_abort_command(self, limit_type: LimitType) -> None: - """ - Send abort command due to limit violation. - - Args: - limit_type: Type of limit exceeded - """ - if not self.command_channel or not self._execution_started or self._execution_ended or self._abort_sent: - return - - # Format detailed reason message - if limit_type == LimitType.STEP_LIMIT: - reason = f"Maximum execution steps exceeded: {self.step_count} > {self.max_steps}" - elif limit_type == LimitType.TIME_LIMIT: - elapsed_time = time.time() - self.start_time if self.start_time else 0 - reason = f"Maximum execution time exceeded: {elapsed_time:.2f}s > {self.max_time}s" - - self.logger.warning("Execution limit exceeded: %s", reason) - - try: - # Send abort command to the engine - abort_command = AbortCommand(command_type=CommandType.ABORT, reason=reason) - self.command_channel.send_command(abort_command) - - # Mark that abort has been sent to prevent duplicate commands - self._abort_sent = True - - self.logger.debug("Abort command sent to engine") - - except Exception: - self.logger.exception("Failed to send abort command") diff --git a/api/dify_graph/graph_engine/manager.py b/api/dify_graph/graph_engine/manager.py deleted file mode 100644 index 955c149069..0000000000 --- a/api/dify_graph/graph_engine/manager.py +++ /dev/null @@ -1,79 +0,0 @@ -""" -GraphEngine Manager for sending control commands via Redis channel. - -This module provides a simplified interface for controlling workflow executions -using the new Redis command channel, without requiring user permission checks. -Callers must provide a Redis client dependency from outside the workflow package. 
-""" - -import logging -from collections.abc import Sequence -from typing import final - -from dify_graph.graph_engine.command_channels.redis_channel import RedisChannel, RedisClientProtocol -from dify_graph.graph_engine.entities.commands import ( - AbortCommand, - GraphEngineCommand, - PauseCommand, - UpdateVariablesCommand, - VariableUpdate, -) - -logger = logging.getLogger(__name__) - - -@final -class GraphEngineManager: - """ - Manager for sending control commands to GraphEngine instances. - - This class provides a simple interface for controlling workflow executions - by sending commands through Redis channels, without user validation. - """ - - _redis_client: RedisClientProtocol - - def __init__(self, redis_client: RedisClientProtocol) -> None: - self._redis_client = redis_client - - def send_stop_command(self, task_id: str, reason: str | None = None) -> None: - """ - Send a stop command to a running workflow. - - Args: - task_id: The task ID of the workflow to stop - reason: Optional reason for stopping (defaults to "User requested stop") - """ - abort_command = AbortCommand(reason=reason or "User requested stop") - self._send_command(task_id, abort_command) - - def send_pause_command(self, task_id: str, reason: str | None = None) -> None: - """Send a pause command to a running workflow.""" - - pause_command = PauseCommand(reason=reason or "User requested pause") - self._send_command(task_id, pause_command) - - def send_update_variables_command(self, task_id: str, updates: Sequence[VariableUpdate]) -> None: - """Send a command to update variables in a running workflow.""" - - if not updates: - return - - update_command = UpdateVariablesCommand(updates=updates) - self._send_command(task_id, update_command) - - def _send_command(self, task_id: str, command: GraphEngineCommand) -> None: - """Send a command to the workflow-specific Redis channel.""" - - if not task_id: - return - - channel_key = f"workflow:{task_id}:commands" - channel = 
RedisChannel(self._redis_client, channel_key) - - try: - channel.send_command(command) - except Exception: - # Silently fail if Redis is unavailable - # The legacy control mechanisms will still work - logger.exception("Failed to send graph engine command %s for task %s", command.__class__.__name__, task_id) diff --git a/api/dify_graph/graph_engine/orchestration/__init__.py b/api/dify_graph/graph_engine/orchestration/__init__.py deleted file mode 100644 index de08e942fb..0000000000 --- a/api/dify_graph/graph_engine/orchestration/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -""" -Orchestration subsystem for graph engine. - -This package coordinates the overall execution flow between -different subsystems. -""" - -from .dispatcher import Dispatcher -from .execution_coordinator import ExecutionCoordinator - -__all__ = [ - "Dispatcher", - "ExecutionCoordinator", -] diff --git a/api/dify_graph/graph_engine/orchestration/dispatcher.py b/api/dify_graph/graph_engine/orchestration/dispatcher.py deleted file mode 100644 index f8aaf20b2f..0000000000 --- a/api/dify_graph/graph_engine/orchestration/dispatcher.py +++ /dev/null @@ -1,143 +0,0 @@ -""" -Main dispatcher for processing events from workers. -""" - -import logging -import queue -import threading -import time -from typing import TYPE_CHECKING, final - -from dify_graph.graph_events import ( - GraphNodeEventBase, - NodeRunExceptionEvent, - NodeRunFailedEvent, - NodeRunSucceededEvent, -) - -from ..event_management import EventManager -from .execution_coordinator import ExecutionCoordinator - -if TYPE_CHECKING: - from ..event_management import EventHandler - -logger = logging.getLogger(__name__) - - -@final -class Dispatcher: - """ - Main dispatcher that processes events from the event queue. - - This runs in a separate thread and coordinates event processing - with timeout and completion detection. 
- """ - - _COMMAND_TRIGGER_EVENTS = ( - NodeRunSucceededEvent, - NodeRunFailedEvent, - NodeRunExceptionEvent, - ) - - def __init__( - self, - event_queue: queue.Queue[GraphNodeEventBase], - event_handler: "EventHandler", - execution_coordinator: ExecutionCoordinator, - event_emitter: EventManager | None = None, - ) -> None: - """ - Initialize the dispatcher. - - Args: - event_queue: Queue of events from workers - event_handler: Event handler registry for processing events - execution_coordinator: Coordinator for execution flow - event_emitter: Optional event manager to signal completion - """ - self._event_queue = event_queue - self._event_handler = event_handler - self._execution_coordinator = execution_coordinator - self._event_emitter = event_emitter - - self._thread: threading.Thread | None = None - self._stop_event = threading.Event() - self._start_time: float | None = None - - def start(self) -> None: - """Start the dispatcher thread.""" - if self._thread and self._thread.is_alive(): - return - - self._stop_event.clear() - self._start_time = time.time() - self._thread = threading.Thread(target=self._dispatcher_loop, name="GraphDispatcher", daemon=True) - self._thread.start() - - def stop(self) -> None: - """Stop the dispatcher thread.""" - self._stop_event.set() - if self._thread and self._thread.is_alive(): - self._thread.join(timeout=2.0) - - def _dispatcher_loop(self) -> None: - """Main dispatcher loop.""" - try: - self._process_commands() - paused = False - while not self._stop_event.is_set(): - if self._execution_coordinator.aborted or self._execution_coordinator.execution_complete: - break - if self._execution_coordinator.paused: - paused = True - break - - self._execution_coordinator.check_scaling() - try: - event = self._event_queue.get(timeout=0.1) - self._event_handler.dispatch(event) - self._event_queue.task_done() - self._process_commands(event) - except queue.Empty: - time.sleep(0.1) - - self._process_commands() - if paused: - 
self._drain_events_until_idle() - else: - self._drain_event_queue() - - except Exception as e: - logger.exception("Dispatcher error") - self._execution_coordinator.mark_failed(e) - - finally: - self._execution_coordinator.mark_complete() - # Signal the event emitter that execution is complete - if self._event_emitter: - self._event_emitter.mark_complete() - - def _process_commands(self, event: GraphNodeEventBase | None = None): - if event is None or isinstance(event, self._COMMAND_TRIGGER_EVENTS): - self._execution_coordinator.process_commands() - - def _drain_event_queue(self) -> None: - while True: - try: - event = self._event_queue.get(block=False) - self._event_handler.dispatch(event) - self._event_queue.task_done() - except queue.Empty: - break - - def _drain_events_until_idle(self) -> None: - while not self._stop_event.is_set(): - try: - event = self._event_queue.get(timeout=0.1) - self._event_handler.dispatch(event) - self._event_queue.task_done() - self._process_commands(event) - except queue.Empty: - if not self._execution_coordinator.has_executing_nodes(): - break - self._drain_event_queue() diff --git a/api/dify_graph/graph_engine/orchestration/execution_coordinator.py b/api/dify_graph/graph_engine/orchestration/execution_coordinator.py deleted file mode 100644 index 0f8550eb12..0000000000 --- a/api/dify_graph/graph_engine/orchestration/execution_coordinator.py +++ /dev/null @@ -1,104 +0,0 @@ -""" -Execution coordinator for managing overall workflow execution. -""" - -from typing import final - -from ..command_processing import CommandProcessor -from ..domain import GraphExecution -from ..graph_state_manager import GraphStateManager -from ..worker_management import WorkerPool - - -@final -class ExecutionCoordinator: - """ - Coordinates overall execution flow between subsystems. - - This provides high-level coordination methods used by the - dispatcher to manage execution state. 
- """ - - def __init__( - self, - graph_execution: GraphExecution, - state_manager: GraphStateManager, - command_processor: CommandProcessor, - worker_pool: WorkerPool, - ) -> None: - """ - Initialize the execution coordinator. - - Args: - graph_execution: Graph execution aggregate - state_manager: Unified state manager - command_processor: Processor for commands - worker_pool: Pool of workers - """ - self._graph_execution = graph_execution - self._state_manager = state_manager - self._command_processor = command_processor - self._worker_pool = worker_pool - - def process_commands(self) -> None: - """Process any pending commands.""" - self._command_processor.process_commands() - - def check_scaling(self) -> None: - """Check and perform worker scaling if needed.""" - self._worker_pool.check_and_scale() - - @property - def execution_complete(self): - return self._state_manager.is_execution_complete() - - @property - def aborted(self): - return self._graph_execution.aborted or self._graph_execution.has_error - - @property - def paused(self) -> bool: - """Expose whether the underlying graph execution is paused.""" - return self._graph_execution.is_paused - - def mark_complete(self) -> None: - """Mark execution as complete.""" - if self._graph_execution.is_paused: - return - if not self._graph_execution.completed: - self._graph_execution.complete() - - def mark_failed(self, error: Exception) -> None: - """ - Mark execution as failed. 
- - Args: - error: The error that caused failure - """ - self._graph_execution.fail(error) - - def handle_pause_if_needed(self) -> None: - """If the execution has been paused, stop workers immediately.""" - - if not self._graph_execution.is_paused: - return - - self._worker_pool.stop() - self._state_manager.clear_executing() - - def handle_abort_if_needed(self) -> None: - """If the execution has been aborted, stop workers immediately.""" - - if not self._graph_execution.aborted: - return - - self._worker_pool.stop() - self._state_manager.clear_executing() - - def has_executing_nodes(self) -> bool: - """Return True if any nodes are currently marked as executing.""" - # This check is only safe once execution has already paused. - # Before pause, executing state can change concurrently, which makes the result unreliable. - if not self._graph_execution.is_paused: - raise AssertionError("has_executing_nodes should only be called after execution is paused") - return self._state_manager.get_executing_count() > 0 diff --git a/api/dify_graph/graph_engine/protocols/command_channel.py b/api/dify_graph/graph_engine/protocols/command_channel.py deleted file mode 100644 index fabd8634c8..0000000000 --- a/api/dify_graph/graph_engine/protocols/command_channel.py +++ /dev/null @@ -1,41 +0,0 @@ -""" -CommandChannel protocol for GraphEngine command communication. - -This protocol defines the interface for sending and receiving commands -to/from a GraphEngine instance, supporting both local and distributed scenarios. -""" - -from typing import Protocol - -from ..entities.commands import GraphEngineCommand - - -class CommandChannel(Protocol): - """ - Protocol for bidirectional command communication with GraphEngine. - - Since each GraphEngine instance processes only one workflow execution, - this channel is dedicated to that single execution. - """ - - def fetch_commands(self) -> list[GraphEngineCommand]: - """ - Fetch pending commands for this GraphEngine instance. 
- - Called by GraphEngine to poll for commands that need to be processed. - - Returns: - List of pending commands (may be empty) - """ - ... - - def send_command(self, command: GraphEngineCommand) -> None: - """ - Send a command to be processed by this GraphEngine instance. - - Called by external systems to send control commands to the running workflow. - - Args: - command: The command to send - """ - ... diff --git a/api/dify_graph/graph_engine/ready_queue/__init__.py b/api/dify_graph/graph_engine/ready_queue/__init__.py deleted file mode 100644 index acba0e961c..0000000000 --- a/api/dify_graph/graph_engine/ready_queue/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -""" -Ready queue implementations for GraphEngine. - -This package contains the protocol and implementations for managing -the queue of nodes ready for execution. -""" - -from .factory import create_ready_queue_from_state -from .in_memory import InMemoryReadyQueue -from .protocol import ReadyQueue, ReadyQueueState - -__all__ = ["InMemoryReadyQueue", "ReadyQueue", "ReadyQueueState", "create_ready_queue_from_state"] diff --git a/api/dify_graph/graph_engine/ready_queue/factory.py b/api/dify_graph/graph_engine/ready_queue/factory.py deleted file mode 100644 index a9d4f470e5..0000000000 --- a/api/dify_graph/graph_engine/ready_queue/factory.py +++ /dev/null @@ -1,37 +0,0 @@ -""" -Factory for creating ReadyQueue instances from serialized state. -""" - -from __future__ import annotations - -from typing import TYPE_CHECKING - -from .in_memory import InMemoryReadyQueue -from .protocol import ReadyQueueState - -if TYPE_CHECKING: - from .protocol import ReadyQueue - - -def create_ready_queue_from_state(state: ReadyQueueState) -> ReadyQueue: - """ - Create a ReadyQueue instance from a serialized state. 
- - Args: - state: The serialized queue state (Pydantic model, dict, or JSON string), or None for a new empty queue - - Returns: - A ReadyQueue instance initialized with the given state - - Raises: - ValueError: If the queue type is unknown or version is unsupported - """ - if state.type == "InMemoryReadyQueue": - if state.version != "1.0": - raise ValueError(f"Unsupported InMemoryReadyQueue version: {state.version}") - queue = InMemoryReadyQueue() - # Always pass as JSON string to loads() - queue.loads(state.model_dump_json()) - return queue - else: - raise ValueError(f"Unknown ready queue type: {state.type}") diff --git a/api/dify_graph/graph_engine/ready_queue/in_memory.py b/api/dify_graph/graph_engine/ready_queue/in_memory.py deleted file mode 100644 index f2c265ece0..0000000000 --- a/api/dify_graph/graph_engine/ready_queue/in_memory.py +++ /dev/null @@ -1,140 +0,0 @@ -""" -In-memory implementation of the ReadyQueue protocol. - -This implementation wraps Python's standard queue.Queue and adds -serialization capabilities for state storage. -""" - -import queue -from typing import final - -from .protocol import ReadyQueue, ReadyQueueState - - -@final -class InMemoryReadyQueue(ReadyQueue): - """ - In-memory ready queue implementation with serialization support. - - This implementation uses Python's queue.Queue internally and provides - methods to serialize and restore the queue state. - """ - - def __init__(self, maxsize: int = 0) -> None: - """ - Initialize the in-memory ready queue. - - Args: - maxsize: Maximum size of the queue (0 for unlimited) - """ - self._queue: queue.Queue[str] = queue.Queue(maxsize=maxsize) - - def put(self, item: str) -> None: - """ - Add a node ID to the ready queue. - - Args: - item: The node ID to add to the queue - """ - self._queue.put(item) - - def get(self, timeout: float | None = None) -> str: - """ - Retrieve and remove a node ID from the queue. 
- - Args: - timeout: Maximum time to wait for an item (None for blocking) - - Returns: - The node ID retrieved from the queue - - Raises: - queue.Empty: If timeout expires and no item is available - """ - if timeout is None: - return self._queue.get(block=True) - return self._queue.get(timeout=timeout) - - def task_done(self) -> None: - """ - Indicate that a previously retrieved task is complete. - - Used by worker threads to signal task completion for - join() synchronization. - """ - self._queue.task_done() - - def empty(self) -> bool: - """ - Check if the queue is empty. - - Returns: - True if the queue has no items, False otherwise - """ - return self._queue.empty() - - def qsize(self) -> int: - """ - Get the approximate size of the queue. - - Returns: - The approximate number of items in the queue - """ - return self._queue.qsize() - - def dumps(self) -> str: - """ - Serialize the queue state to a JSON string for storage. - - Returns: - A JSON string containing the serialized queue state - """ - # Extract all items from the queue without removing them - items: list[str] = [] - temp_items: list[str] = [] - - # Drain the queue temporarily to get all items - while not self._queue.empty(): - try: - item = self._queue.get_nowait() - temp_items.append(item) - items.append(item) - except queue.Empty: - break - - # Put items back in the same order - for item in temp_items: - self._queue.put(item) - - state = ReadyQueueState( - type="InMemoryReadyQueue", - version="1.0", - items=items, - ) - return state.model_dump_json() - - def loads(self, data: str) -> None: - """ - Restore the queue state from a JSON string. 
- - Args: - data: The JSON string containing the serialized queue state to restore - """ - state = ReadyQueueState.model_validate_json(data) - - if state.type != "InMemoryReadyQueue": - raise ValueError(f"Invalid serialized data type: {state.type}") - - if state.version != "1.0": - raise ValueError(f"Unsupported version: {state.version}") - - # Clear the current queue - while not self._queue.empty(): - try: - self._queue.get_nowait() - except queue.Empty: - break - - # Restore items - for item in state.items: - self._queue.put(item) diff --git a/api/dify_graph/graph_engine/ready_queue/protocol.py b/api/dify_graph/graph_engine/ready_queue/protocol.py deleted file mode 100644 index 97d3ea6dd2..0000000000 --- a/api/dify_graph/graph_engine/ready_queue/protocol.py +++ /dev/null @@ -1,104 +0,0 @@ -""" -ReadyQueue protocol for GraphEngine node execution queue. - -This protocol defines the interface for managing the queue of nodes ready -for execution, supporting both in-memory and persistent storage scenarios. -""" - -from collections.abc import Sequence -from typing import Protocol - -from pydantic import BaseModel, Field - - -class ReadyQueueState(BaseModel): - """ - Pydantic model for serialized ready queue state. - - This defines the structure of the data returned by dumps() - and expected by loads() for ready queue serialization. - """ - - type: str = Field(description="Queue implementation type (e.g., 'InMemoryReadyQueue')") - version: str = Field(description="Serialization format version") - items: Sequence[str] = Field(default_factory=list, description="List of node IDs in the queue") - - -class ReadyQueue(Protocol): - """ - Protocol for managing nodes ready for execution in GraphEngine. - - This protocol defines the interface that any ready queue implementation - must provide, enabling both in-memory queues and persistent queues - that can be serialized for state storage. - """ - - def put(self, item: str) -> None: - """ - Add a node ID to the ready queue. 
- - Args: - item: The node ID to add to the queue - """ - ... - - def get(self, timeout: float | None = None) -> str: - """ - Retrieve and remove a node ID from the queue. - - Args: - timeout: Maximum time to wait for an item (None for blocking) - - Returns: - The node ID retrieved from the queue - - Raises: - queue.Empty: If timeout expires and no item is available - """ - ... - - def task_done(self) -> None: - """ - Indicate that a previously retrieved task is complete. - - Used by worker threads to signal task completion for - join() synchronization. - """ - ... - - def empty(self) -> bool: - """ - Check if the queue is empty. - - Returns: - True if the queue has no items, False otherwise - """ - ... - - def qsize(self) -> int: - """ - Get the approximate size of the queue. - - Returns: - The approximate number of items in the queue - """ - ... - - def dumps(self) -> str: - """ - Serialize the queue state to a JSON string for storage. - - Returns: - A JSON string containing the serialized queue state - that can be persisted and later restored - """ - ... - - def loads(self, data: str) -> None: - """ - Restore the queue state from a JSON string. - - Args: - data: The JSON string containing the serialized queue state to restore - """ - ... diff --git a/api/dify_graph/graph_engine/response_coordinator/__init__.py b/api/dify_graph/graph_engine/response_coordinator/__init__.py deleted file mode 100644 index e11d31199c..0000000000 --- a/api/dify_graph/graph_engine/response_coordinator/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -""" -ResponseStreamCoordinator - Coordinates streaming output from response nodes - -This component manages response streaming sessions and ensures ordered streaming -of responses based on upstream node outputs and constants. 
-""" - -from .coordinator import ResponseStreamCoordinator - -__all__ = ["ResponseStreamCoordinator"] diff --git a/api/dify_graph/graph_engine/response_coordinator/coordinator.py b/api/dify_graph/graph_engine/response_coordinator/coordinator.py deleted file mode 100644 index 941a8a496b..0000000000 --- a/api/dify_graph/graph_engine/response_coordinator/coordinator.py +++ /dev/null @@ -1,697 +0,0 @@ -""" -Main ResponseStreamCoordinator implementation. - -This module contains the public ResponseStreamCoordinator class that manages -response streaming sessions and ensures ordered streaming of responses. -""" - -import logging -from collections import deque -from collections.abc import Sequence -from threading import RLock -from typing import Literal, TypeAlias, final -from uuid import uuid4 - -from pydantic import BaseModel, Field - -from dify_graph.enums import NodeExecutionType, NodeState -from dify_graph.graph_events import NodeRunStreamChunkEvent, NodeRunSucceededEvent -from dify_graph.nodes.base.template import TextSegment, VariableSegment -from dify_graph.runtime import VariablePool -from dify_graph.runtime.graph_runtime_state import GraphProtocol - -from .path import Path -from .session import ResponseSession - -logger = logging.getLogger(__name__) - -# Type definitions -NodeID: TypeAlias = str -EdgeID: TypeAlias = str - - -class ResponseSessionState(BaseModel): - """Serializable representation of a response session.""" - - node_id: str - index: int = Field(default=0, ge=0) - - -class StreamBufferState(BaseModel): - """Serializable representation of buffered stream chunks.""" - - selector: tuple[str, ...] - events: list[NodeRunStreamChunkEvent] = Field(default_factory=list) - - -class StreamPositionState(BaseModel): - """Serializable representation for stream read positions.""" - - selector: tuple[str, ...] 
- position: int = Field(default=0, ge=0) - - -class ResponseStreamCoordinatorState(BaseModel): - """Serialized snapshot of ResponseStreamCoordinator.""" - - type: Literal["ResponseStreamCoordinator"] = Field(default="ResponseStreamCoordinator") - version: str = Field(default="1.0") - response_nodes: Sequence[str] = Field(default_factory=list) - active_session: ResponseSessionState | None = None - waiting_sessions: Sequence[ResponseSessionState] = Field(default_factory=list) - pending_sessions: Sequence[ResponseSessionState] = Field(default_factory=list) - node_execution_ids: dict[str, str] = Field(default_factory=dict) - paths_map: dict[str, list[list[str]]] = Field(default_factory=dict) - stream_buffers: Sequence[StreamBufferState] = Field(default_factory=list) - stream_positions: Sequence[StreamPositionState] = Field(default_factory=list) - closed_streams: Sequence[tuple[str, ...]] = Field(default_factory=list) - - -@final -class ResponseStreamCoordinator: - """ - Manages response streaming sessions without relying on global state. - - Ensures ordered streaming of responses based on upstream node outputs and constants. - """ - - def __init__(self, variable_pool: "VariablePool", graph: GraphProtocol) -> None: - """ - Initialize coordinator with variable pool. 
- - Args: - variable_pool: VariablePool instance for accessing node variables - graph: Graph instance for looking up node information - """ - self._variable_pool = variable_pool - self._graph = graph - self._active_session: ResponseSession | None = None - self._waiting_sessions: deque[ResponseSession] = deque() - self._lock = RLock() - - # Internal stream management (replacing OutputRegistry) - self._stream_buffers: dict[tuple[str, ...], list[NodeRunStreamChunkEvent]] = {} - self._stream_positions: dict[tuple[str, ...], int] = {} - self._closed_streams: set[tuple[str, ...]] = set() - - # Track response nodes - self._response_nodes: set[NodeID] = set() - - # Store paths for each response node - self._paths_maps: dict[NodeID, list[Path]] = {} - - # Track node execution IDs and types for proper event forwarding - self._node_execution_ids: dict[NodeID, str] = {} # node_id -> execution_id - - # Track response sessions to ensure only one per node - self._response_sessions: dict[NodeID, ResponseSession] = {} # node_id -> session - - def register(self, response_node_id: NodeID) -> None: - with self._lock: - if response_node_id in self._response_nodes: - return - self._response_nodes.add(response_node_id) - - # Build and save paths map for this response node - paths_map = self._build_paths_map(response_node_id) - self._paths_maps[response_node_id] = paths_map - - # Create and store response session for this node - response_node = self._graph.nodes[response_node_id] - session = ResponseSession.from_node(response_node) - self._response_sessions[response_node_id] = session - - def track_node_execution(self, node_id: NodeID, execution_id: str) -> None: - """Track the execution ID for a node when it starts executing. 
- - Args: - node_id: The ID of the node - execution_id: The execution ID from NodeRunStartedEvent - """ - with self._lock: - self._node_execution_ids[node_id] = execution_id - - def _get_or_create_execution_id(self, node_id: NodeID) -> str: - """Get the execution ID for a node, creating one if it doesn't exist. - - Args: - node_id: The ID of the node - - Returns: - The execution ID for the node - """ - with self._lock: - if node_id not in self._node_execution_ids: - self._node_execution_ids[node_id] = str(uuid4()) - return self._node_execution_ids[node_id] - - def _build_paths_map(self, response_node_id: NodeID) -> list[Path]: - """ - Build a paths map for a response node by finding all paths from root node - to the response node, recording branch edges along each path. - - Args: - response_node_id: ID of the response node to analyze - - Returns: - List of Path objects, where each path contains branch edge IDs - """ - # Get root node ID - root_node_id = self._graph.root_node.id - - # If root is the response node, return empty path - if root_node_id == response_node_id: - return [Path()] - - # Extract variable selectors from the response node's template - response_node = self._graph.nodes[response_node_id] - response_session = ResponseSession.from_node(response_node) - template = response_session.template - - # Collect all variable selectors from the template - variable_selectors: set[tuple[str, ...]] = set() - for segment in template.segments: - if isinstance(segment, VariableSegment): - variable_selectors.add(tuple(segment.selector[:2])) - - # Step 1: Find all complete paths from root to response node - all_complete_paths: list[list[EdgeID]] = [] - - def find_paths( - current_node_id: NodeID, target_node_id: NodeID, current_path: list[EdgeID], visited: set[NodeID] - ) -> None: - """Recursively find all paths from current node to target node.""" - if current_node_id == target_node_id: - # Found a complete path, store it - 
all_complete_paths.append(current_path.copy()) - return - - # Mark as visited to avoid cycles - visited.add(current_node_id) - - # Explore outgoing edges - outgoing_edges = self._graph.get_outgoing_edges(current_node_id) - for edge in outgoing_edges: - edge_id = edge.id - next_node_id = edge.head - - # Skip if already visited in this path - if next_node_id not in visited: - # Add edge to path and recurse - new_path = current_path + [edge_id] - find_paths(next_node_id, target_node_id, new_path, visited.copy()) - - # Start searching from root node - find_paths(root_node_id, response_node_id, [], set()) - - # Step 2: For each complete path, filter edges based on node blocking behavior - filtered_paths: list[Path] = [] - for path in all_complete_paths: - blocking_edges: list[str] = [] - for edge_id in path: - edge = self._graph.edges[edge_id] - source_node = self._graph.nodes[edge.tail] - - # Check if node is a branch, container, or response node - if source_node.execution_type in { - NodeExecutionType.BRANCH, - NodeExecutionType.CONTAINER, - NodeExecutionType.RESPONSE, - } or source_node.blocks_variable_output(variable_selectors): - blocking_edges.append(edge_id) - - # Keep the path even if it's empty - filtered_paths.append(Path(edges=blocking_edges)) - - return filtered_paths - - def on_edge_taken(self, edge_id: str) -> Sequence[NodeRunStreamChunkEvent]: - """ - Handle when an edge is taken (selected by a branch node). - - This method updates the paths for all response nodes by removing - the taken edge. If any response node has an empty path after removal, - it means the node is now deterministically reachable and should start. 
- - Args: - edge_id: The ID of the edge that was taken - - Returns: - List of events to emit from starting new sessions - """ - events: list[NodeRunStreamChunkEvent] = [] - - with self._lock: - # Check each response node in order - for response_node_id in self._response_nodes: - if response_node_id not in self._paths_maps: - continue - - paths = self._paths_maps[response_node_id] - has_reachable_path = False - - # Update each path by removing the taken edge - for path in paths: - # Remove the taken edge from this path - path.remove_edge(edge_id) - - # Check if this path is now empty (node is reachable) - if path.is_empty(): - has_reachable_path = True - - # If node is now reachable (has empty path), start/queue session - if has_reachable_path: - # Pass the node_id to the activation method - # The method will handle checking and removing from map - events.extend(self._active_or_queue_session(response_node_id)) - return events - - def _active_or_queue_session(self, node_id: str) -> Sequence[NodeRunStreamChunkEvent]: - """ - Start a session immediately if no active session, otherwise queue it. - Only activates sessions that exist in the _response_sessions map. 
- - Args: - node_id: The ID of the response node to activate - - Returns: - List of events from flush attempt if session started immediately - """ - events: list[NodeRunStreamChunkEvent] = [] - - # Get the session from our map (only activate if it exists) - session = self._response_sessions.get(node_id) - if not session: - return events - - # Remove from map to ensure it won't be activated again - del self._response_sessions[node_id] - - if self._active_session is None: - self._active_session = session - - # Try to flush immediately - events.extend(self.try_flush()) - else: - # Queue the session if another is active - self._waiting_sessions.append(session) - - return events - - def intercept_event( - self, event: NodeRunStreamChunkEvent | NodeRunSucceededEvent - ) -> Sequence[NodeRunStreamChunkEvent]: - with self._lock: - if isinstance(event, NodeRunStreamChunkEvent): - self._append_stream_chunk(event.selector, event) - if event.is_final: - self._close_stream(event.selector) - return self.try_flush() - else: - # Skip cause we share the same variable pool. - # - # for variable_name, variable_value in event.node_run_result.outputs.items(): - # self._variable_pool.add((event.node_id, variable_name), variable_value) - return self.try_flush() - - def _create_stream_chunk_event( - self, - node_id: str, - execution_id: str, - selector: Sequence[str], - chunk: str, - is_final: bool = False, - ) -> NodeRunStreamChunkEvent: - """Create a stream chunk event with consistent structure. - - For selectors with special prefixes (sys, env, conversation), we use the - active response node's information since these are not actual node IDs. 
- """ - # Check if this is a special selector that doesn't correspond to a node - if selector and selector[0] not in self._graph.nodes and self._active_session: - # Use the active response node for special selectors - response_node = self._graph.nodes[self._active_session.node_id] - return NodeRunStreamChunkEvent( - id=execution_id, - node_id=response_node.id, - node_type=response_node.node_type, - selector=selector, - chunk=chunk, - is_final=is_final, - ) - - # Standard case: selector refers to an actual node - node = self._graph.nodes[node_id] - return NodeRunStreamChunkEvent( - id=execution_id, - node_id=node.id, - node_type=node.node_type, - selector=selector, - chunk=chunk, - is_final=is_final, - ) - - def _process_variable_segment(self, segment: VariableSegment) -> tuple[Sequence[NodeRunStreamChunkEvent], bool]: - """Process a variable segment. Returns (events, is_complete). - - Handles both regular node selectors and special system selectors (sys, env, conversation). - For special selectors, we attribute the output to the active response node. 
- """ - events: list[NodeRunStreamChunkEvent] = [] - source_selector_prefix = segment.selector[0] if segment.selector else "" - is_complete = False - - # Determine which node to attribute the output to - # For special selectors (sys, env, conversation), use the active response node - # For regular selectors, use the source node - if self._active_session and source_selector_prefix not in self._graph.nodes: - # Special selector - use active response node - output_node_id = self._active_session.node_id - else: - # Regular node selector - output_node_id = source_selector_prefix - execution_id = self._get_or_create_execution_id(output_node_id) - - # Stream all available chunks - while self._has_unread_stream(segment.selector): - if event := self._pop_stream_chunk(segment.selector): - # For special selectors, we need to update the event to use - # the active response node's information - if self._active_session and source_selector_prefix not in self._graph.nodes: - response_node = self._graph.nodes[self._active_session.node_id] - # Create a new event with the response node's information - # but keep the original selector - updated_event = NodeRunStreamChunkEvent( - id=execution_id, - node_id=response_node.id, - node_type=response_node.node_type, - selector=event.selector, # Keep original selector - chunk=event.chunk, - is_final=event.is_final, - ) - events.append(updated_event) - else: - # Regular node selector - use event as is - events.append(event) - - # Check if this is the last chunk by looking ahead - stream_closed = self._is_stream_closed(segment.selector) - # Check if stream is closed to determine if segment is complete - if stream_closed: - is_complete = True - - elif value := self._variable_pool.get(segment.selector): - # Process scalar value - is_last_segment = bool( - self._active_session and self._active_session.index == len(self._active_session.template.segments) - 1 - ) - events.append( - self._create_stream_chunk_event( - node_id=output_node_id, - 
execution_id=execution_id, - selector=segment.selector, - chunk=value.markdown, - is_final=is_last_segment, - ) - ) - is_complete = True - - return events, is_complete - - def _process_text_segment(self, segment: TextSegment) -> Sequence[NodeRunStreamChunkEvent]: - """Process a text segment. Returns (events, is_complete).""" - assert self._active_session is not None - current_response_node = self._graph.nodes[self._active_session.node_id] - - # Use get_or_create_execution_id to ensure we have a consistent ID - execution_id = self._get_or_create_execution_id(current_response_node.id) - - is_last_segment = self._active_session.index == len(self._active_session.template.segments) - 1 - event = self._create_stream_chunk_event( - node_id=current_response_node.id, - execution_id=execution_id, - selector=[current_response_node.id, "answer"], # FIXME(-LAN-) - chunk=segment.text, - is_final=is_last_segment, - ) - return [event] - - def try_flush(self) -> list[NodeRunStreamChunkEvent]: - with self._lock: - if not self._active_session: - return [] - - template = self._active_session.template - response_node_id = self._active_session.node_id - - events: list[NodeRunStreamChunkEvent] = [] - - # Process segments sequentially from current index - while self._active_session.index < len(template.segments): - segment = template.segments[self._active_session.index] - - if isinstance(segment, VariableSegment): - # Check if the source node for this variable is skipped - # Only check for actual nodes, not special selectors (sys, env, conversation) - source_selector_prefix = segment.selector[0] if segment.selector else "" - if source_selector_prefix in self._graph.nodes: - source_node = self._graph.nodes[source_selector_prefix] - - if source_node.state == NodeState.SKIPPED: - # Skip this variable segment if the source node is skipped - self._active_session.index += 1 - continue - - segment_events, is_complete = self._process_variable_segment(segment) - events.extend(segment_events) - - # 
Only advance index if this variable segment is complete - if is_complete: - self._active_session.index += 1 - else: - # Wait for more data - break - - else: - segment_events = self._process_text_segment(segment) - events.extend(segment_events) - self._active_session.index += 1 - - if self._active_session.is_complete(): - # End current session and get events from starting next session - next_session_events = self.end_session(response_node_id) - events.extend(next_session_events) - - return events - - def end_session(self, node_id: str) -> list[NodeRunStreamChunkEvent]: - """ - End the active session for a response node. - Automatically starts the next waiting session if available. - - Args: - node_id: ID of the response node ending its session - - Returns: - List of events from starting the next session - """ - with self._lock: - events: list[NodeRunStreamChunkEvent] = [] - - if self._active_session and self._active_session.node_id == node_id: - self._active_session = None - - # Try to start next waiting session - if self._waiting_sessions: - next_session = self._waiting_sessions.popleft() - self._active_session = next_session - - # Immediately try to flush any available segments - events = self.try_flush() - - return events - - # ============= Internal Stream Management Methods ============= - - def _append_stream_chunk(self, selector: Sequence[str], event: NodeRunStreamChunkEvent) -> None: - """ - Append a stream chunk to the internal buffer. 
- - Args: - selector: List of strings identifying the stream location - event: The NodeRunStreamChunkEvent to append - - Raises: - ValueError: If the stream is already closed - """ - key = tuple(selector) - - if key in self._closed_streams: - raise ValueError(f"Stream {'.'.join(selector)} is already closed") - - if key not in self._stream_buffers: - self._stream_buffers[key] = [] - self._stream_positions[key] = 0 - - self._stream_buffers[key].append(event) - - def _pop_stream_chunk(self, selector: Sequence[str]) -> NodeRunStreamChunkEvent | None: - """ - Pop the next unread stream chunk from the buffer. - - Args: - selector: List of strings identifying the stream location - - Returns: - The next event, or None if no unread events available - """ - key = tuple(selector) - - if key not in self._stream_buffers: - return None - - position = self._stream_positions.get(key, 0) - buffer = self._stream_buffers[key] - - if position >= len(buffer): - return None - - event = buffer[position] - self._stream_positions[key] = position + 1 - return event - - def _has_unread_stream(self, selector: Sequence[str]) -> bool: - """ - Check if the stream has unread events. - - Args: - selector: List of strings identifying the stream location - - Returns: - True if there are unread events, False otherwise - """ - key = tuple(selector) - - if key not in self._stream_buffers: - return False - - position = self._stream_positions.get(key, 0) - return position < len(self._stream_buffers[key]) - - def _close_stream(self, selector: Sequence[str]) -> None: - """ - Mark a stream as closed (no more chunks can be appended). - - Args: - selector: List of strings identifying the stream location - """ - key = tuple(selector) - self._closed_streams.add(key) - - def _is_stream_closed(self, selector: Sequence[str]) -> bool: - """ - Check if a stream is closed. 
- - Args: - selector: List of strings identifying the stream location - - Returns: - True if the stream is closed, False otherwise - """ - key = tuple(selector) - return key in self._closed_streams - - def _serialize_session(self, session: ResponseSession | None) -> ResponseSessionState | None: - """Convert an in-memory session into its serializable form.""" - - if session is None: - return None - return ResponseSessionState(node_id=session.node_id, index=session.index) - - def _session_from_state(self, session_state: ResponseSessionState) -> ResponseSession: - """Rebuild a response session from serialized data.""" - - node = self._graph.nodes.get(session_state.node_id) - if node is None: - raise ValueError(f"Unknown response node '{session_state.node_id}' in serialized state") - - session = ResponseSession.from_node(node) - session.index = session_state.index - return session - - def dumps(self) -> str: - """Serialize coordinator state to JSON.""" - - with self._lock: - state = ResponseStreamCoordinatorState( - response_nodes=sorted(self._response_nodes), - active_session=self._serialize_session(self._active_session), - waiting_sessions=[ - session_state - for session in list(self._waiting_sessions) - if (session_state := self._serialize_session(session)) is not None - ], - pending_sessions=[ - session_state - for _, session in sorted(self._response_sessions.items()) - if (session_state := self._serialize_session(session)) is not None - ], - node_execution_ids=dict(sorted(self._node_execution_ids.items())), - paths_map={ - node_id: [path.edges.copy() for path in paths] - for node_id, paths in sorted(self._paths_maps.items()) - }, - stream_buffers=[ - StreamBufferState( - selector=selector, - events=[event.model_copy(deep=True) for event in events], - ) - for selector, events in sorted(self._stream_buffers.items()) - ], - stream_positions=[ - StreamPositionState(selector=selector, position=position) - for selector, position in sorted(self._stream_positions.items()) 
- ], - closed_streams=sorted(self._closed_streams), - ) - return state.model_dump_json() - - def loads(self, data: str) -> None: - """Restore coordinator state from JSON.""" - - state = ResponseStreamCoordinatorState.model_validate_json(data) - - if state.type != "ResponseStreamCoordinator": - raise ValueError(f"Invalid serialized data type: {state.type}") - - if state.version != "1.0": - raise ValueError(f"Unsupported serialized version: {state.version}") - - with self._lock: - self._response_nodes = set(state.response_nodes) - self._paths_maps = { - node_id: [Path(edges=list(path_edges)) for path_edges in paths] - for node_id, paths in state.paths_map.items() - } - self._node_execution_ids = dict(state.node_execution_ids) - - self._stream_buffers = { - tuple(buffer.selector): [event.model_copy(deep=True) for event in buffer.events] - for buffer in state.stream_buffers - } - self._stream_positions = { - tuple(position.selector): position.position for position in state.stream_positions - } - for selector in self._stream_buffers: - self._stream_positions.setdefault(selector, 0) - - self._closed_streams = {tuple(selector) for selector in state.closed_streams} - - self._waiting_sessions = deque( - self._session_from_state(session_state) for session_state in state.waiting_sessions - ) - self._response_sessions = { - session_state.node_id: self._session_from_state(session_state) - for session_state in state.pending_sessions - } - self._active_session = self._session_from_state(state.active_session) if state.active_session else None diff --git a/api/dify_graph/graph_engine/response_coordinator/path.py b/api/dify_graph/graph_engine/response_coordinator/path.py deleted file mode 100644 index 50f2f4eb21..0000000000 --- a/api/dify_graph/graph_engine/response_coordinator/path.py +++ /dev/null @@ -1,35 +0,0 @@ -""" -Internal path representation for response coordinator. 
- -This module contains the private Path class used internally by ResponseStreamCoordinator -to track execution paths to response nodes. -""" - -from dataclasses import dataclass, field -from typing import TypeAlias - -EdgeID: TypeAlias = str - - -@dataclass -class Path: - """ - Represents a path of branch edges that must be taken to reach a response node. - - Note: This is an internal class not exposed in the public API. - """ - - edges: list[EdgeID] = field(default_factory=list[EdgeID]) - - def contains_edge(self, edge_id: EdgeID) -> bool: - """Check if this path contains the given edge.""" - return edge_id in self.edges - - def remove_edge(self, edge_id: EdgeID) -> None: - """Remove the given edge from this path in place.""" - if self.contains_edge(edge_id): - self.edges.remove(edge_id) - - def is_empty(self) -> bool: - """Check if the path has no edges (node is reachable).""" - return len(self.edges) == 0 diff --git a/api/dify_graph/graph_engine/response_coordinator/session.py b/api/dify_graph/graph_engine/response_coordinator/session.py deleted file mode 100644 index 11a9f5dac5..0000000000 --- a/api/dify_graph/graph_engine/response_coordinator/session.py +++ /dev/null @@ -1,66 +0,0 @@ -""" -Internal response session management for response coordinator. - -This module contains the private ResponseSession class used internally -by ResponseStreamCoordinator to manage streaming sessions. -""" - -from __future__ import annotations - -from dataclasses import dataclass -from typing import Protocol, cast - -from dify_graph.nodes.base.template import Template -from dify_graph.runtime.graph_runtime_state import NodeProtocol - - -class _ResponseSessionNodeProtocol(NodeProtocol, Protocol): - """Structural contract required from nodes that can open a response session.""" - - def get_streaming_template(self) -> Template: ... - - -@dataclass -class ResponseSession: - """ - Represents an active response streaming session. 
- - Note: This is an internal class not exposed in the public API. - """ - - node_id: str - template: Template # Template object from the response node - index: int = 0 # Current position in the template segments - - @classmethod - def from_node(cls, node: NodeProtocol) -> ResponseSession: - """ - Create a ResponseSession from a response-capable node. - - The parameter is typed as `NodeProtocol` because the graph is exposed behind a protocol at the runtime layer. - At runtime this must be a node that implements `get_streaming_template()`. The coordinator decides which - graph nodes should be treated as response-capable before they reach this factory. - - Args: - node: Node from the materialized workflow graph. - - Returns: - ResponseSession configured with the node's streaming template - - Raises: - TypeError: If node does not implement the response-session streaming contract. - """ - response_node = cast(_ResponseSessionNodeProtocol, node) - try: - template = response_node.get_streaming_template() - except AttributeError as exc: - raise TypeError("ResponseSession.from_node requires get_streaming_template() on response nodes") from exc - - return cls( - node_id=node.id, - template=template, - ) - - def is_complete(self) -> bool: - """Check if all segments in the template have been processed.""" - return self.index >= len(self.template.segments) diff --git a/api/dify_graph/graph_engine/worker.py b/api/dify_graph/graph_engine/worker.py deleted file mode 100644 index 988c20d72a..0000000000 --- a/api/dify_graph/graph_engine/worker.py +++ /dev/null @@ -1,205 +0,0 @@ -""" -Worker - Thread implementation for queue-based node execution - -Workers pull node IDs from the ready_queue, execute nodes, and push events -to the event_queue for the dispatcher to process. 
-""" - -import queue -import threading -import time -from collections.abc import Sequence -from datetime import datetime -from typing import TYPE_CHECKING, final - -from typing_extensions import override - -from dify_graph.context import IExecutionContext -from dify_graph.enums import WorkflowNodeExecutionStatus -from dify_graph.graph import Graph -from dify_graph.graph_engine.layers.base import GraphEngineLayer -from dify_graph.graph_events import GraphNodeEventBase, NodeRunFailedEvent, NodeRunStartedEvent, is_node_result_event -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.base.node import Node -from libs.datetime_utils import naive_utc_now - -from .ready_queue import ReadyQueue - -if TYPE_CHECKING: - pass - - -@final -class Worker(threading.Thread): - """ - Worker thread that executes nodes from the ready queue. - - Workers continuously pull node IDs from the ready_queue, execute the - corresponding nodes, and push the resulting events to the event_queue - for the dispatcher to process. - """ - - def __init__( - self, - ready_queue: ReadyQueue, - event_queue: queue.Queue[GraphNodeEventBase], - graph: Graph, - layers: Sequence[GraphEngineLayer], - worker_id: int = 0, - execution_context: IExecutionContext | None = None, - ) -> None: - """ - Initialize worker thread. 
- - Args: - ready_queue: Ready queue containing node IDs ready for execution - event_queue: Queue for pushing execution events - graph: Graph containing nodes to execute - layers: Graph engine layers for node execution hooks - worker_id: Unique identifier for this worker - execution_context: Optional execution context for context preservation - """ - super().__init__(name=f"GraphWorker-{worker_id}", daemon=True) - self._ready_queue = ready_queue - self._event_queue = event_queue - self._graph = graph - self._worker_id = worker_id - self._execution_context = execution_context - self._stop_event = threading.Event() - self._layers = layers if layers is not None else [] - self._last_task_time = time.time() - self._current_node_started_at: datetime | None = None - - def stop(self) -> None: - """Signal the worker to stop processing.""" - self._stop_event.set() - - @property - def is_idle(self) -> bool: - """Check if the worker is currently idle.""" - # Worker is idle if it hasn't processed a task recently (within 0.2 seconds) - return (time.time() - self._last_task_time) > 0.2 - - @property - def idle_duration(self) -> float: - """Get the duration in seconds since the worker last processed a task.""" - return time.time() - self._last_task_time - - @property - def worker_id(self) -> int: - """Get the worker's ID.""" - return self._worker_id - - @override - def run(self) -> None: - """ - Main worker loop. - - Continuously pulls node IDs from ready_queue, executes them, - and pushes events to event_queue until stopped. 
- """ - while not self._stop_event.is_set(): - # Try to get a node ID from the ready queue (with timeout) - try: - node_id = self._ready_queue.get(timeout=0.1) - except queue.Empty: - continue - - self._last_task_time = time.time() - node = self._graph.nodes[node_id] - try: - self._current_node_started_at = None - self._execute_node(node) - self._ready_queue.task_done() - except Exception as e: - self._event_queue.put( - self._build_fallback_failure_event(node, e, started_at=self._current_node_started_at) - ) - finally: - self._current_node_started_at = None - - def _execute_node(self, node: Node) -> None: - """ - Execute a single node and handle its events. - - Args: - node: The node instance to execute - """ - node.ensure_execution_id() - - error: Exception | None = None - result_event: GraphNodeEventBase | None = None - - # Execute the node with preserved context if execution context is provided - if self._execution_context is not None: - with self._execution_context: - self._invoke_node_run_start_hooks(node) - try: - node_events = node.run() - for event in node_events: - if isinstance(event, NodeRunStartedEvent) and event.id == node.execution_id: - self._current_node_started_at = event.start_at - self._event_queue.put(event) - if is_node_result_event(event): - result_event = event - except Exception as exc: - error = exc - raise - finally: - self._invoke_node_run_end_hooks(node, error, result_event) - else: - self._invoke_node_run_start_hooks(node) - try: - node_events = node.run() - for event in node_events: - if isinstance(event, NodeRunStartedEvent) and event.id == node.execution_id: - self._current_node_started_at = event.start_at - self._event_queue.put(event) - if is_node_result_event(event): - result_event = event - except Exception as exc: - error = exc - raise - finally: - self._invoke_node_run_end_hooks(node, error, result_event) - - def _invoke_node_run_start_hooks(self, node: Node) -> None: - """Invoke on_node_run_start hooks for all layers.""" - 
for layer in self._layers: - try: - layer.on_node_run_start(node) - except Exception: - # Silently ignore layer errors to prevent disrupting node execution - continue - - def _invoke_node_run_end_hooks( - self, node: Node, error: Exception | None, result_event: GraphNodeEventBase | None = None - ) -> None: - """Invoke on_node_run_end hooks for all layers.""" - for layer in self._layers: - try: - layer.on_node_run_end(node, error, result_event) - except Exception: - # Silently ignore layer errors to prevent disrupting node execution - continue - - def _build_fallback_failure_event( - self, node: Node, error: Exception, *, started_at: datetime | None = None - ) -> NodeRunFailedEvent: - """Build a failed event when worker-level execution aborts before a node emits its own result event.""" - failure_time = naive_utc_now() - error_message = str(error) - return NodeRunFailedEvent( - id=node.execution_id, - node_id=node.id, - node_type=node.node_type, - in_iteration_id=None, - error=error_message, - start_at=started_at or failure_time, - finished_at=failure_time, - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error=error_message, - error_type=type(error).__name__, - ), - ) diff --git a/api/dify_graph/graph_engine/worker_management/__init__.py b/api/dify_graph/graph_engine/worker_management/__init__.py deleted file mode 100644 index 03de1f6daa..0000000000 --- a/api/dify_graph/graph_engine/worker_management/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -""" -Worker management subsystem for graph engine. - -This package manages the worker pool, including creation, -scaling, and activity tracking. 
-""" - -from .worker_pool import WorkerPool - -__all__ = [ - "WorkerPool", -] diff --git a/api/dify_graph/graph_engine/worker_management/worker_pool.py b/api/dify_graph/graph_engine/worker_management/worker_pool.py deleted file mode 100644 index cc93087783..0000000000 --- a/api/dify_graph/graph_engine/worker_management/worker_pool.py +++ /dev/null @@ -1,277 +0,0 @@ -""" -Simple worker pool that consolidates functionality. - -This is a simpler implementation that merges WorkerPool, ActivityTracker, -DynamicScaler, and WorkerFactory into a single class. -""" - -import logging -import queue -import threading -from typing import final - -from dify_graph.context import IExecutionContext -from dify_graph.graph import Graph -from dify_graph.graph_events import GraphNodeEventBase - -from ..config import GraphEngineConfig -from ..layers.base import GraphEngineLayer -from ..ready_queue import ReadyQueue -from ..worker import Worker - -logger = logging.getLogger(__name__) - - -@final -class WorkerPool: - """ - Simple worker pool with integrated management. - - This class consolidates all worker management functionality into - a single, simpler implementation without excessive abstraction. - """ - - def __init__( - self, - ready_queue: ReadyQueue, - event_queue: queue.Queue[GraphNodeEventBase], - graph: Graph, - layers: list[GraphEngineLayer], - config: GraphEngineConfig, - execution_context: IExecutionContext | None = None, - ) -> None: - """ - Initialize the simple worker pool. 
- - Args: - ready_queue: Ready queue for nodes ready for execution - event_queue: Queue for worker events - graph: The workflow graph - layers: Graph engine layers for node execution hooks - config: GraphEngine worker pool configuration - execution_context: Optional execution context for context preservation - """ - self._ready_queue = ready_queue - self._event_queue = event_queue - self._graph = graph - self._execution_context = execution_context - self._layers = layers - self._config = config - - # Worker management - self._workers: list[Worker] = [] - self._worker_counter = 0 - self._lock = threading.RLock() - self._running = False - - # No longer tracking worker states with callbacks to avoid lock contention - - def start(self, initial_count: int | None = None) -> None: - """ - Start the worker pool. - - Args: - initial_count: Number of workers to start with (auto-calculated if None) - """ - with self._lock: - if self._running: - return - - self._running = True - - # Calculate initial worker count - if initial_count is None: - node_count = len(self._graph.nodes) - if node_count < 10: - initial_count = self._config.min_workers - elif node_count < 50: - initial_count = min(self._config.min_workers + 1, self._config.max_workers) - else: - initial_count = min(self._config.min_workers + 2, self._config.max_workers) - - logger.debug( - "Starting worker pool: %d workers (nodes=%d, min=%d, max=%d)", - initial_count, - node_count, - self._config.min_workers, - self._config.max_workers, - ) - - # Create initial workers - for _ in range(initial_count): - self._create_worker() - - def stop(self) -> None: - """Stop all workers in the pool.""" - with self._lock: - self._running = False - worker_count = len(self._workers) - - if worker_count > 0: - logger.debug("Stopping worker pool: %d workers", worker_count) - - # Stop all workers - for worker in self._workers: - worker.stop() - - # Wait for workers to finish - for worker in self._workers: - if worker.is_alive(): - 
worker.join(timeout=2.0) - - self._workers.clear() - - def _create_worker(self) -> None: - """Create and start a new worker.""" - worker_id = self._worker_counter - self._worker_counter += 1 - - worker = Worker( - ready_queue=self._ready_queue, - event_queue=self._event_queue, - graph=self._graph, - layers=self._layers, - worker_id=worker_id, - execution_context=self._execution_context, - ) - - worker.start() - self._workers.append(worker) - - def _remove_worker(self, worker: Worker, worker_id: int) -> None: - """Remove a specific worker from the pool.""" - # Stop the worker - worker.stop() - - # Wait for it to finish - if worker.is_alive(): - worker.join(timeout=2.0) - - # Remove from list - if worker in self._workers: - self._workers.remove(worker) - - def _try_scale_up(self, queue_depth: int, current_count: int) -> bool: - """ - Try to scale up workers if needed. - - Args: - queue_depth: Current queue depth - current_count: Current number of workers - - Returns: - True if scaled up, False otherwise - """ - if queue_depth > self._config.scale_up_threshold and current_count < self._config.max_workers: - old_count = current_count - self._create_worker() - - logger.debug( - "Scaled up workers: %d -> %d (queue_depth=%d exceeded threshold=%d)", - old_count, - len(self._workers), - queue_depth, - self._config.scale_up_threshold, - ) - return True - return False - - def _try_scale_down(self, queue_depth: int, current_count: int, active_count: int, idle_count: int) -> bool: - """ - Try to scale down workers if we have excess capacity. 
- - Args: - queue_depth: Current queue depth - current_count: Current number of workers - active_count: Number of active workers - idle_count: Number of idle workers - - Returns: - True if scaled down, False otherwise - """ - # Skip if we're at minimum or have no idle workers - if current_count <= self._config.min_workers or idle_count == 0: - return False - - # Check if we have excess capacity - has_excess_capacity = ( - queue_depth <= active_count # Active workers can handle current queue - or idle_count > active_count # More idle than active workers - or (queue_depth == 0 and idle_count > 0) # No work and have idle workers - ) - - if not has_excess_capacity: - return False - - # Find and remove idle workers that have been idle long enough - workers_to_remove: list[tuple[Worker, int]] = [] - - for worker in self._workers: - # Check if worker is idle and has exceeded idle time threshold - if worker.is_idle and worker.idle_duration >= self._config.scale_down_idle_time: - # Don't remove if it would leave us unable to handle the queue - remaining_workers = current_count - len(workers_to_remove) - 1 - if remaining_workers >= self._config.min_workers and remaining_workers >= max(1, queue_depth // 2): - workers_to_remove.append((worker, worker.worker_id)) - # Only remove one worker per check to avoid aggressive scaling - break - - # Remove idle workers if any found - if workers_to_remove: - old_count = current_count - for worker, worker_id in workers_to_remove: - self._remove_worker(worker, worker_id) - - logger.debug( - "Scaled down workers: %d -> %d (removed %d idle workers after %.1fs, " - "queue_depth=%d, active=%d, idle=%d)", - old_count, - len(self._workers), - len(workers_to_remove), - self._config.scale_down_idle_time, - queue_depth, - active_count, - idle_count - len(workers_to_remove), - ) - return True - - return False - - def check_and_scale(self) -> None: - """Check and perform scaling if needed.""" - with self._lock: - if not self._running: - return - - 
current_count = len(self._workers) - queue_depth = self._ready_queue.qsize() - - # Count active vs idle workers by querying their state directly - idle_count = sum(1 for worker in self._workers if worker.is_idle) - active_count = current_count - idle_count - - # Try to scale up if queue is backing up - self._try_scale_up(queue_depth, current_count) - - # Try to scale down if we have excess capacity - self._try_scale_down(queue_depth, current_count, active_count, idle_count) - - def get_worker_count(self) -> int: - """Get current number of workers.""" - with self._lock: - return len(self._workers) - - def get_status(self) -> dict[str, int]: - """ - Get pool status information. - - Returns: - Dictionary with status information - """ - with self._lock: - return { - "total_workers": len(self._workers), - "queue_depth": self._ready_queue.qsize(), - "min_workers": self._config.min_workers, - "max_workers": self._config.max_workers, - } diff --git a/api/dify_graph/graph_events/__init__.py b/api/dify_graph/graph_events/__init__.py deleted file mode 100644 index 56ea642092..0000000000 --- a/api/dify_graph/graph_events/__init__.py +++ /dev/null @@ -1,82 +0,0 @@ -# Agent events -from .agent import NodeRunAgentLogEvent - -# Base events -from .base import ( - BaseGraphEvent, - GraphEngineEvent, - GraphNodeEventBase, -) - -# Graph events -from .graph import ( - GraphRunAbortedEvent, - GraphRunFailedEvent, - GraphRunPartialSucceededEvent, - GraphRunPausedEvent, - GraphRunStartedEvent, - GraphRunSucceededEvent, -) - -# Iteration events -from .iteration import ( - NodeRunIterationFailedEvent, - NodeRunIterationNextEvent, - NodeRunIterationStartedEvent, - NodeRunIterationSucceededEvent, -) - -# Loop events -from .loop import ( - NodeRunLoopFailedEvent, - NodeRunLoopNextEvent, - NodeRunLoopStartedEvent, - NodeRunLoopSucceededEvent, -) - -# Node events -from .node import ( - NodeRunExceptionEvent, - NodeRunFailedEvent, - NodeRunHumanInputFormFilledEvent, - 
NodeRunHumanInputFormTimeoutEvent, - NodeRunPauseRequestedEvent, - NodeRunRetrieverResourceEvent, - NodeRunRetryEvent, - NodeRunStartedEvent, - NodeRunStreamChunkEvent, - NodeRunSucceededEvent, - is_node_result_event, -) - -__all__ = [ - "BaseGraphEvent", - "GraphEngineEvent", - "GraphNodeEventBase", - "GraphRunAbortedEvent", - "GraphRunFailedEvent", - "GraphRunPartialSucceededEvent", - "GraphRunPausedEvent", - "GraphRunStartedEvent", - "GraphRunSucceededEvent", - "NodeRunAgentLogEvent", - "NodeRunExceptionEvent", - "NodeRunFailedEvent", - "NodeRunHumanInputFormFilledEvent", - "NodeRunHumanInputFormTimeoutEvent", - "NodeRunIterationFailedEvent", - "NodeRunIterationNextEvent", - "NodeRunIterationStartedEvent", - "NodeRunIterationSucceededEvent", - "NodeRunLoopFailedEvent", - "NodeRunLoopNextEvent", - "NodeRunLoopStartedEvent", - "NodeRunLoopSucceededEvent", - "NodeRunPauseRequestedEvent", - "NodeRunRetrieverResourceEvent", - "NodeRunRetryEvent", - "NodeRunStartedEvent", - "NodeRunStreamChunkEvent", - "NodeRunSucceededEvent", - "is_node_result_event", -] diff --git a/api/dify_graph/graph_events/agent.py b/api/dify_graph/graph_events/agent.py deleted file mode 100644 index 759fe3a71c..0000000000 --- a/api/dify_graph/graph_events/agent.py +++ /dev/null @@ -1,17 +0,0 @@ -from collections.abc import Mapping -from typing import Any - -from pydantic import Field - -from .base import GraphAgentNodeEventBase - - -class NodeRunAgentLogEvent(GraphAgentNodeEventBase): - message_id: str = Field(..., description="message id") - label: str = Field(..., description="label") - node_execution_id: str = Field(..., description="node execution id") - parent_id: str | None = Field(..., description="parent id") - error: str | None = Field(..., description="error") - status: str = Field(..., description="status") - data: Mapping[str, Any] = Field(..., description="data") - metadata: Mapping[str, object] = Field(default_factory=dict) diff --git a/api/dify_graph/graph_events/base.py 
b/api/dify_graph/graph_events/base.py deleted file mode 100644 index 4560cf5085..0000000000 --- a/api/dify_graph/graph_events/base.py +++ /dev/null @@ -1,31 +0,0 @@ -from pydantic import BaseModel, Field - -from dify_graph.enums import NodeType -from dify_graph.node_events import NodeRunResult - - -class GraphEngineEvent(BaseModel): - pass - - -class BaseGraphEvent(GraphEngineEvent): - pass - - -class GraphNodeEventBase(GraphEngineEvent): - id: str = Field(..., description="node execution id") - node_id: str - node_type: NodeType - - in_iteration_id: str | None = None - """iteration id if node is in iteration""" - in_loop_id: str | None = None - """loop id if node is in loop""" - - # The version of the node, or "1" if not specified. - node_version: str = "1" - node_run_result: NodeRunResult = Field(default_factory=NodeRunResult) - - -class GraphAgentNodeEventBase(GraphNodeEventBase): - pass diff --git a/api/dify_graph/graph_events/graph.py b/api/dify_graph/graph_events/graph.py deleted file mode 100644 index f4aaba64d6..0000000000 --- a/api/dify_graph/graph_events/graph.py +++ /dev/null @@ -1,57 +0,0 @@ -from pydantic import Field - -from dify_graph.entities.pause_reason import PauseReason -from dify_graph.entities.workflow_start_reason import WorkflowStartReason -from dify_graph.graph_events import BaseGraphEvent - - -class GraphRunStartedEvent(BaseGraphEvent): - # Reason is emitted for workflow start events and is always set. 
- reason: WorkflowStartReason = Field( - default=WorkflowStartReason.INITIAL, - description="reason for workflow start", - ) - - -class GraphRunSucceededEvent(BaseGraphEvent): - """Event emitted when a run completes successfully with final outputs.""" - - outputs: dict[str, object] = Field( - default_factory=dict, - description="Final workflow outputs keyed by output selector.", - ) - - -class GraphRunFailedEvent(BaseGraphEvent): - error: str = Field(..., description="failed reason") - exceptions_count: int = Field(description="exception count", default=0) - - -class GraphRunPartialSucceededEvent(BaseGraphEvent): - """Event emitted when a run finishes with partial success and failures.""" - - exceptions_count: int = Field(..., description="exception count") - outputs: dict[str, object] = Field( - default_factory=dict, - description="Outputs that were materialised before failures occurred.", - ) - - -class GraphRunAbortedEvent(BaseGraphEvent): - """Event emitted when a graph run is aborted by user command.""" - - reason: str | None = Field(default=None, description="reason for abort") - outputs: dict[str, object] = Field( - default_factory=dict, - description="Outputs produced before the abort was requested.", - ) - - -class GraphRunPausedEvent(BaseGraphEvent): - """Event emitted when a graph run is paused by user command.""" - - reasons: list[PauseReason] = Field(description="reason for pause", default_factory=list) - outputs: dict[str, object] = Field( - default_factory=dict, - description="Outputs available to the client while the run is paused.", - ) diff --git a/api/dify_graph/graph_events/human_input.py b/api/dify_graph/graph_events/human_input.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/api/dify_graph/graph_events/iteration.py b/api/dify_graph/graph_events/iteration.py deleted file mode 100644 index 28627395fd..0000000000 --- a/api/dify_graph/graph_events/iteration.py +++ /dev/null @@ -1,40 +0,0 @@ -from collections.abc import 
Mapping -from datetime import datetime -from typing import Any - -from pydantic import Field - -from .base import GraphNodeEventBase - - -class NodeRunIterationStartedEvent(GraphNodeEventBase): - node_title: str - start_at: datetime = Field(..., description="start at") - inputs: Mapping[str, object] = Field(default_factory=dict) - metadata: Mapping[str, object] = Field(default_factory=dict) - predecessor_node_id: str | None = None - - -class NodeRunIterationNextEvent(GraphNodeEventBase): - node_title: str - index: int = Field(..., description="index") - pre_iteration_output: Any = None - - -class NodeRunIterationSucceededEvent(GraphNodeEventBase): - node_title: str - start_at: datetime = Field(..., description="start at") - inputs: Mapping[str, object] = Field(default_factory=dict) - outputs: Mapping[str, object] = Field(default_factory=dict) - metadata: Mapping[str, object] = Field(default_factory=dict) - steps: int = 0 - - -class NodeRunIterationFailedEvent(GraphNodeEventBase): - node_title: str - start_at: datetime = Field(..., description="start at") - inputs: Mapping[str, object] = Field(default_factory=dict) - outputs: Mapping[str, object] = Field(default_factory=dict) - metadata: Mapping[str, object] = Field(default_factory=dict) - steps: int = 0 - error: str = Field(..., description="failed reason") diff --git a/api/dify_graph/graph_events/loop.py b/api/dify_graph/graph_events/loop.py deleted file mode 100644 index 7cdc5427e2..0000000000 --- a/api/dify_graph/graph_events/loop.py +++ /dev/null @@ -1,40 +0,0 @@ -from collections.abc import Mapping -from datetime import datetime -from typing import Any - -from pydantic import Field - -from .base import GraphNodeEventBase - - -class NodeRunLoopStartedEvent(GraphNodeEventBase): - node_title: str - start_at: datetime = Field(..., description="start at") - inputs: Mapping[str, object] = Field(default_factory=dict) - metadata: Mapping[str, object] = Field(default_factory=dict) - predecessor_node_id: str | None = 
None - - -class NodeRunLoopNextEvent(GraphNodeEventBase): - node_title: str - index: int = Field(..., description="index") - pre_loop_output: Any = None - - -class NodeRunLoopSucceededEvent(GraphNodeEventBase): - node_title: str - start_at: datetime = Field(..., description="start at") - inputs: Mapping[str, object] = Field(default_factory=dict) - outputs: Mapping[str, object] = Field(default_factory=dict) - metadata: Mapping[str, object] = Field(default_factory=dict) - steps: int = 0 - - -class NodeRunLoopFailedEvent(GraphNodeEventBase): - node_title: str - start_at: datetime = Field(..., description="start at") - inputs: Mapping[str, object] = Field(default_factory=dict) - outputs: Mapping[str, object] = Field(default_factory=dict) - metadata: Mapping[str, object] = Field(default_factory=dict) - steps: int = 0 - error: str = Field(..., description="failed reason") diff --git a/api/dify_graph/graph_events/node.py b/api/dify_graph/graph_events/node.py deleted file mode 100644 index df19d6c03b..0000000000 --- a/api/dify_graph/graph_events/node.py +++ /dev/null @@ -1,99 +0,0 @@ -from collections.abc import Sequence -from datetime import datetime - -from pydantic import Field - -from core.rag.entities.citation_metadata import RetrievalSourceMetadata -from dify_graph.entities.pause_reason import PauseReason - -from .base import GraphNodeEventBase - - -class NodeRunStartedEvent(GraphNodeEventBase): - node_title: str - predecessor_node_id: str | None = None - start_at: datetime = Field(..., description="node start time") - extras: dict[str, object] = Field(default_factory=dict) - - # FIXME(-LAN-): only for ToolNode - provider_type: str = "" - provider_id: str = "" - - -class NodeRunStreamChunkEvent(GraphNodeEventBase): - # Spec-compliant fields - selector: Sequence[str] = Field( - ..., description="selector identifying the output location (e.g., ['nodeA', 'text'])" - ) - chunk: str = Field(..., description="the actual chunk content") - is_final: bool = 
Field(default=False, description="indicates if this is the last chunk") - - -class NodeRunRetrieverResourceEvent(GraphNodeEventBase): - retriever_resources: Sequence[RetrievalSourceMetadata] = Field(..., description="retriever resources") - context: str = Field(..., description="context") - - -class NodeRunSucceededEvent(GraphNodeEventBase): - start_at: datetime = Field(..., description="node start time") - finished_at: datetime | None = Field(default=None, description="node finish time") - - -class NodeRunFailedEvent(GraphNodeEventBase): - error: str = Field(..., description="error") - start_at: datetime = Field(..., description="node start time") - finished_at: datetime | None = Field(default=None, description="node finish time") - - -class NodeRunExceptionEvent(GraphNodeEventBase): - error: str = Field(..., description="error") - start_at: datetime = Field(..., description="node start time") - finished_at: datetime | None = Field(default=None, description="node finish time") - - -class NodeRunRetryEvent(NodeRunStartedEvent): - error: str = Field(..., description="error") - retry_index: int = Field(..., description="which retry attempt is about to be performed") - - -class NodeRunHumanInputFormFilledEvent(GraphNodeEventBase): - """Emitted when a HumanInput form is submitted and before the node finishes.""" - - node_title: str = Field(..., description="HumanInput node title") - rendered_content: str = Field(..., description="Markdown content rendered with user inputs.") - action_id: str = Field(..., description="User action identifier chosen in the form.") - action_text: str = Field(..., description="Display text of the chosen action button.") - - -class NodeRunHumanInputFormTimeoutEvent(GraphNodeEventBase): - """Emitted when a HumanInput form times out.""" - - node_title: str = Field(..., description="HumanInput node title") - expiration_time: datetime = Field(..., description="Form expiration time") - - -class NodeRunPauseRequestedEvent(GraphNodeEventBase): - 
reason: PauseReason = Field(..., description="pause reason") - - -def is_node_result_event(event: GraphNodeEventBase) -> bool: - """ - Check if an event is a final result event from node execution. - - A result event indicates the completion of a node execution and contains - runtime information such as inputs, outputs, or error details. - - Args: - event: The event to check - - Returns: - True if the event is a node result event (succeeded/failed/paused), False otherwise - """ - return isinstance( - event, - ( - NodeRunSucceededEvent, - NodeRunFailedEvent, - NodeRunPauseRequestedEvent, - ), - ) diff --git a/api/dify_graph/model_runtime/README.md b/api/dify_graph/model_runtime/README.md deleted file mode 100644 index b9d2c55210..0000000000 --- a/api/dify_graph/model_runtime/README.md +++ /dev/null @@ -1,51 +0,0 @@ -# Model Runtime - -This module provides the interface for invoking and authenticating various models, and offers Dify a unified information and credentials form rule for model providers. - -- On one hand, it decouples models from upstream and downstream processes, facilitating horizontal expansion for developers, -- On the other hand, it allows for direct display of providers and models in the frontend interface by simply defining them in the backend, eliminating the need to modify frontend logic. - -## Features - -- Supports capability invocation for 6 types of models - - - `LLM` - LLM text completion, dialogue, pre-computed tokens capability - - `Text Embedding Model` - Text Embedding, pre-computed tokens capability - - `Rerank Model` - Segment Rerank capability - - `Speech-to-text Model` - Speech to text capability - - `Text-to-speech Model` - Text to speech capability - - `Moderation` - Moderation capability - -- Model provider display - - Displays a list of all supported providers, including provider names, icons, supported model types list, predefined model list, configuration method, and credentials form rules, etc. 
- -- Selectable model list display - - After configuring provider/model credentials, the dropdown (application orchestration interface/default model) allows viewing of the available LLM list. Greyed out items represent predefined model lists from providers without configured credentials, facilitating user review of supported models. - - In addition, this list also returns configurable parameter information and rules for LLM. These parameters are all defined in the backend, allowing different settings for various parameters supported by different models. - -- Provider/model credential authentication - - The provider list returns configuration information for the credentials form, which can be authenticated through Runtime's interface. - -## Structure - -Model Runtime is divided into three layers: - -- The outermost layer is the factory method - - It provides methods for obtaining all providers, all model lists, getting provider instances, and authenticating provider/model credentials. - -- The second layer is the provider layer - - It provides the current provider's model list, model instance obtaining, provider credential authentication, and provider configuration rule information, **allowing horizontal expansion** to support different providers. - -- The bottom layer is the model layer - - It offers direct invocation of various model types, predefined model configuration information, getting predefined/remote model lists, model credential authentication methods. Different models provide additional special methods, like LLM's pre-computed tokens method, cost information obtaining method, etc., **allowing horizontal expansion** for different models under the same provider (within supported model types). - -## Documentation - -For detailed documentation on how to add new providers or models, please refer to the [Dify documentation](https://docs.dify.ai/). 
diff --git a/api/dify_graph/model_runtime/README_CN.md b/api/dify_graph/model_runtime/README_CN.md deleted file mode 100644 index 0a8b56b3fe..0000000000 --- a/api/dify_graph/model_runtime/README_CN.md +++ /dev/null @@ -1,64 +0,0 @@ -# Model Runtime - -该模块提供了各模型的调用、鉴权接口,并为 Dify 提供了统一的模型供应商的信息和凭据表单规则。 - -- 一方面将模型和上下游解耦,方便开发者对模型横向扩展, -- 另一方面提供了只需在后端定义供应商和模型,即可在前端页面直接展示,无需修改前端逻辑。 - -## 功能介绍 - -- 支持 6 种模型类型的能力调用 - - - `LLM` - LLM 文本补全、对话,预计算 tokens 能力 - - `Text Embedding Model` - 文本 Embedding,预计算 tokens 能力 - - `Rerank Model` - 分段 Rerank 能力 - - `Speech-to-text Model` - 语音转文本能力 - - `Text-to-speech Model` - 文本转语音能力 - - `Moderation` - Moderation 能力 - -- 模型供应商展示 - - 展示所有已支持的供应商列表,除了返回供应商名称、图标之外,还提供了支持的模型类型列表,预定义模型列表、配置方式以及配置凭据的表单规则等等。 - -- 可选择的模型列表展示 - - 配置供应商/模型凭据后,可在此下拉(应用编排界面/默认模型)查看可用的 LLM 列表,其中灰色的为未配置凭据供应商的预定义模型列表,方便用户查看已支持的模型。 - - 除此之外,该列表还返回了 LLM 可配置的参数信息和规则。这里的参数均为后端定义,相比之前只有 5 种固定参数,这里可为不同模型设置所支持的各种参数。 - -- 供应商/模型凭据鉴权 - - 供应商列表返回了凭据表单的配置信息,可通过 Runtime 提供的接口对凭据进行鉴权。 - -## 结构 - -Model Runtime 分三层: - -- 最外层为工厂方法 - - 提供获取所有供应商、所有模型列表、获取供应商实例、供应商/模型凭据鉴权方法。 - -- 第二层为供应商层 - - 提供获取当前供应商模型列表、获取模型实例、供应商凭据鉴权、供应商配置规则信息,**可横向扩展**以支持不同的供应商。 - - 对于供应商/模型凭据,有两种情况 - - - 如 OpenAI 这类中心化供应商,需要定义如**api_key**这类的鉴权凭据 - - 如[**Xinference**](https://github.com/xorbitsai/inference)这类本地部署的供应商,需要定义如**server_url**这类的地址凭据,有时候还需要定义**model_uid**之类的模型类型凭据。当在供应商层定义了这些凭据后,就可以在前端页面上直接展示,无需修改前端逻辑。 - - 当配置好凭据后,就可以通过 DifyRuntime 的外部接口直接获取到对应供应商所需要的**Schema**(凭据表单规则),从而在可以在不修改前端逻辑的情况下,提供新的供应商/模型的支持。 - -- 最底层为模型层 - - 提供各种模型类型的直接调用、预定义模型配置信息、获取预定义/远程模型列表、模型凭据鉴权方法,不同模型额外提供了特殊方法,如 LLM 提供预计算 tokens 方法、获取费用信息方法等,**可横向扩展**同供应商下不同的模型(支持的模型类型下)。 - - 在这里我们需要先区分模型参数与模型凭据。 - - - 模型参数 (**在本层定义**):这是一类经常需要变动,随时调整的参数,如 LLM 的 **max_tokens**、**temperature** 等,这些参数是由用户在前端页面上进行调整的,因此需要在后端定义参数的规则,以便前端页面进行展示和调整。在 DifyRuntime 中,他们的参数名一般为**model_parameters: dict[str, any]**。 - - - 模型凭据 (**在供应商层定义**):这是一类不经常变动,一般在配置好后就不会再变动的参数,如 **api_key**、**server_url** 等。在 DifyRuntime 中,他们的参数名一般为**credentials: dict[str, any]**,Provider 层的 
credentials 会直接被传递到这一层,不需要再单独定义。 - -## 文档 - -有关如何添加新供应商或模型的详细文档,请参阅 [Dify 文档](https://docs.dify.ai/)。 diff --git a/api/dify_graph/model_runtime/callbacks/base_callback.py b/api/dify_graph/model_runtime/callbacks/base_callback.py deleted file mode 100644 index 20faf3d6cd..0000000000 --- a/api/dify_graph/model_runtime/callbacks/base_callback.py +++ /dev/null @@ -1,151 +0,0 @@ -from abc import ABC, abstractmethod -from collections.abc import Sequence - -from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk -from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool -from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel - -_TEXT_COLOR_MAPPING = { - "blue": "36;1", - "yellow": "33;1", - "pink": "38;5;200", - "green": "32;1", - "red": "31;1", -} - - -class Callback(ABC): - """ - Base class for callbacks. - Only for LLM. - """ - - raise_error: bool = False - - @abstractmethod - def on_before_invoke( - self, - llm_instance: AIModel, - model: str, - credentials: dict, - prompt_messages: list[PromptMessage], - model_parameters: dict, - tools: list[PromptMessageTool] | None = None, - stop: Sequence[str] | None = None, - stream: bool = True, - user: str | None = None, - ): - """ - Before invoke callback - - :param llm_instance: LLM instance - :param model: model name - :param credentials: model credentials - :param prompt_messages: prompt messages - :param model_parameters: model parameters - :param tools: tools for tool calling - :param stop: stop words - :param stream: is stream response - :param user: unique user id - """ - raise NotImplementedError() - - @abstractmethod - def on_new_chunk( - self, - llm_instance: AIModel, - chunk: LLMResultChunk, - model: str, - credentials: dict, - prompt_messages: Sequence[PromptMessage], - model_parameters: dict, - tools: list[PromptMessageTool] | None = None, - stop: Sequence[str] | None = None, - stream: bool = True, - user: str | None = None, - 
): - """ - On new chunk callback - - :param llm_instance: LLM instance - :param chunk: chunk - :param model: model name - :param credentials: model credentials - :param prompt_messages: prompt messages - :param model_parameters: model parameters - :param tools: tools for tool calling - :param stop: stop words - :param stream: is stream response - :param user: unique user id - """ - raise NotImplementedError() - - @abstractmethod - def on_after_invoke( - self, - llm_instance: AIModel, - result: LLMResult, - model: str, - credentials: dict, - prompt_messages: Sequence[PromptMessage], - model_parameters: dict, - tools: list[PromptMessageTool] | None = None, - stop: Sequence[str] | None = None, - stream: bool = True, - user: str | None = None, - ): - """ - After invoke callback - - :param llm_instance: LLM instance - :param result: result - :param model: model name - :param credentials: model credentials - :param prompt_messages: prompt messages - :param model_parameters: model parameters - :param tools: tools for tool calling - :param stop: stop words - :param stream: is stream response - :param user: unique user id - """ - raise NotImplementedError() - - @abstractmethod - def on_invoke_error( - self, - llm_instance: AIModel, - ex: Exception, - model: str, - credentials: dict, - prompt_messages: list[PromptMessage], - model_parameters: dict, - tools: list[PromptMessageTool] | None = None, - stop: Sequence[str] | None = None, - stream: bool = True, - user: str | None = None, - ): - """ - Invoke error callback - - :param llm_instance: LLM instance - :param ex: exception - :param model: model name - :param credentials: model credentials - :param prompt_messages: prompt messages - :param model_parameters: model parameters - :param tools: tools for tool calling - :param stop: stop words - :param stream: is stream response - :param user: unique user id - """ - raise NotImplementedError() - - def print_text(self, text: str, color: str | None = None, end: str = ""): - 
"""Print text with highlighting and no end characters.""" - text_to_print = self._get_colored_text(text, color) if color else text - print(text_to_print, end=end) - - def _get_colored_text(self, text: str, color: str) -> str: - """Get colored text.""" - color_str = _TEXT_COLOR_MAPPING[color] - return f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m" diff --git a/api/dify_graph/model_runtime/callbacks/logging_callback.py b/api/dify_graph/model_runtime/callbacks/logging_callback.py deleted file mode 100644 index 49b9ab27eb..0000000000 --- a/api/dify_graph/model_runtime/callbacks/logging_callback.py +++ /dev/null @@ -1,170 +0,0 @@ -import json -import logging -import sys -from collections.abc import Sequence -from typing import cast - -from dify_graph.model_runtime.callbacks.base_callback import Callback -from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk -from dify_graph.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool -from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel - -logger = logging.getLogger(__name__) - - -class LoggingCallback(Callback): - def on_before_invoke( - self, - llm_instance: AIModel, - model: str, - credentials: dict, - prompt_messages: list[PromptMessage], - model_parameters: dict, - tools: list[PromptMessageTool] | None = None, - stop: Sequence[str] | None = None, - stream: bool = True, - user: str | None = None, - ): - """ - Before invoke callback - - :param llm_instance: LLM instance - :param model: model name - :param credentials: model credentials - :param prompt_messages: prompt messages - :param model_parameters: model parameters - :param tools: tools for tool calling - :param stop: stop words - :param stream: is stream response - :param user: unique user id - """ - self.print_text("\n[on_llm_before_invoke]\n", color="blue") - self.print_text(f"Model: {model}\n", color="blue") - self.print_text("Parameters:\n", color="blue") - for key, value in 
model_parameters.items(): - self.print_text(f"\t{key}: {value}\n", color="blue") - - if stop: - self.print_text(f"\tstop: {stop}\n", color="blue") - - if tools: - self.print_text("\tTools:\n", color="blue") - for tool in tools: - self.print_text(f"\t\t{tool.name}\n", color="blue") - - self.print_text(f"Stream: {stream}\n", color="blue") - - if user: - self.print_text(f"User: {user}\n", color="blue") - - self.print_text("Prompt messages:\n", color="blue") - for prompt_message in prompt_messages: - if prompt_message.name: - self.print_text(f"\tname: {prompt_message.name}\n", color="blue") - - self.print_text(f"\trole: {prompt_message.role.value}\n", color="blue") - self.print_text(f"\tcontent: {prompt_message.content}\n", color="blue") - - if stream: - self.print_text("\n[on_llm_new_chunk]") - - def on_new_chunk( - self, - llm_instance: AIModel, - chunk: LLMResultChunk, - model: str, - credentials: dict, - prompt_messages: Sequence[PromptMessage], - model_parameters: dict, - tools: list[PromptMessageTool] | None = None, - stop: Sequence[str] | None = None, - stream: bool = True, - user: str | None = None, - ): - """ - On new chunk callback - - :param llm_instance: LLM instance - :param chunk: chunk - :param model: model name - :param credentials: model credentials - :param prompt_messages: prompt messages - :param model_parameters: model parameters - :param tools: tools for tool calling - :param stop: stop words - :param stream: is stream response - :param user: unique user id - """ - sys.stdout.write(cast(str, chunk.delta.message.content)) - sys.stdout.flush() - - def on_after_invoke( - self, - llm_instance: AIModel, - result: LLMResult, - model: str, - credentials: dict, - prompt_messages: Sequence[PromptMessage], - model_parameters: dict, - tools: list[PromptMessageTool] | None = None, - stop: Sequence[str] | None = None, - stream: bool = True, - user: str | None = None, - ): - """ - After invoke callback - - :param llm_instance: LLM instance - :param result: 
result - :param model: model name - :param credentials: model credentials - :param prompt_messages: prompt messages - :param model_parameters: model parameters - :param tools: tools for tool calling - :param stop: stop words - :param stream: is stream response - :param user: unique user id - """ - self.print_text("\n[on_llm_after_invoke]\n", color="yellow") - self.print_text(f"Content: {result.message.content}\n", color="yellow") - - if result.message.tool_calls: - self.print_text("Tool calls:\n", color="yellow") - for tool_call in result.message.tool_calls: - self.print_text(f"\t{tool_call.id}\n", color="yellow") - self.print_text(f"\t{tool_call.function.name}\n", color="yellow") - self.print_text(f"\t{json.dumps(tool_call.function.arguments)}\n", color="yellow") - - self.print_text(f"Model: {result.model}\n", color="yellow") - self.print_text(f"Usage: {result.usage}\n", color="yellow") - self.print_text(f"System Fingerprint: {result.system_fingerprint}\n", color="yellow") - - def on_invoke_error( - self, - llm_instance: AIModel, - ex: Exception, - model: str, - credentials: dict, - prompt_messages: list[PromptMessage], - model_parameters: dict, - tools: list[PromptMessageTool] | None = None, - stop: Sequence[str] | None = None, - stream: bool = True, - user: str | None = None, - ): - """ - Invoke error callback - - :param llm_instance: LLM instance - :param ex: exception - :param model: model name - :param credentials: model credentials - :param prompt_messages: prompt messages - :param model_parameters: model parameters - :param tools: tools for tool calling - :param stop: stop words - :param stream: is stream response - :param user: unique user id - """ - self.print_text("\n[on_llm_invoke_error]\n", color="red") - logger.exception(ex) diff --git a/api/dify_graph/model_runtime/entities/__init__.py b/api/dify_graph/model_runtime/entities/__init__.py deleted file mode 100644 index a24e437d48..0000000000 --- a/api/dify_graph/model_runtime/entities/__init__.py +++ 
/dev/null @@ -1,43 +0,0 @@ -from .llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage -from .message_entities import ( - AssistantPromptMessage, - AudioPromptMessageContent, - DocumentPromptMessageContent, - ImagePromptMessageContent, - MultiModalPromptMessageContent, - PromptMessage, - PromptMessageContent, - PromptMessageContentType, - PromptMessageRole, - PromptMessageTool, - SystemPromptMessage, - TextPromptMessageContent, - ToolPromptMessage, - UserPromptMessage, - VideoPromptMessageContent, -) -from .model_entities import ModelPropertyKey - -__all__ = [ - "AssistantPromptMessage", - "AudioPromptMessageContent", - "DocumentPromptMessageContent", - "ImagePromptMessageContent", - "LLMMode", - "LLMResult", - "LLMResultChunk", - "LLMResultChunkDelta", - "LLMUsage", - "ModelPropertyKey", - "MultiModalPromptMessageContent", - "PromptMessage", - "PromptMessageContent", - "PromptMessageContentType", - "PromptMessageRole", - "PromptMessageTool", - "SystemPromptMessage", - "TextPromptMessageContent", - "ToolPromptMessage", - "UserPromptMessage", - "VideoPromptMessageContent", -] diff --git a/api/dify_graph/model_runtime/entities/common_entities.py b/api/dify_graph/model_runtime/entities/common_entities.py deleted file mode 100644 index b673efae22..0000000000 --- a/api/dify_graph/model_runtime/entities/common_entities.py +++ /dev/null @@ -1,16 +0,0 @@ -from pydantic import BaseModel, model_validator - - -class I18nObject(BaseModel): - """ - Model class for i18n object. 
- """ - - zh_Hans: str | None = None - en_US: str - - @model_validator(mode="after") - def _(self): - if not self.zh_Hans: - self.zh_Hans = self.en_US - return self diff --git a/api/dify_graph/model_runtime/entities/defaults.py b/api/dify_graph/model_runtime/entities/defaults.py deleted file mode 100644 index 53b732e5c6..0000000000 --- a/api/dify_graph/model_runtime/entities/defaults.py +++ /dev/null @@ -1,130 +0,0 @@ -from dify_graph.model_runtime.entities.model_entities import DefaultParameterName - -PARAMETER_RULE_TEMPLATE: dict[DefaultParameterName, dict] = { - DefaultParameterName.TEMPERATURE: { - "label": { - "en_US": "Temperature", - "zh_Hans": "温度", - }, - "type": "float", - "help": { - "en_US": "Controls randomness. Lower temperature results in less random completions." - " As the temperature approaches zero, the model will become deterministic and repetitive." - " Higher temperature results in more random completions.", - "zh_Hans": "温度控制随机性。较低的温度会导致较少的随机完成。随着温度接近零,模型将变得确定性和重复性。" - "较高的温度会导致更多的随机完成。", - }, - "required": False, - "default": 0.0, - "min": 0.0, - "max": 1.0, - "precision": 2, - }, - DefaultParameterName.TOP_P: { - "label": { - "en_US": "Top P", - "zh_Hans": "Top P", - }, - "type": "float", - "help": { - "en_US": "Controls diversity via nucleus sampling: 0.5 means half of all likelihood-weighted options" - " are considered.", - "zh_Hans": "通过核心采样控制多样性:0.5 表示考虑了一半的所有可能性加权选项。", - }, - "required": False, - "default": 1.0, - "min": 0.0, - "max": 1.0, - "precision": 2, - }, - DefaultParameterName.TOP_K: { - "label": { - "en_US": "Top K", - "zh_Hans": "Top K", - }, - "type": "int", - "help": { - "en_US": "Limits the number of tokens to consider for each step by keeping only the k most likely tokens.", - "zh_Hans": "通过只保留每一步中最可能的 k 个标记来限制要考虑的标记数量。", - }, - "required": False, - "default": 50, - "min": 1, - "max": 100, - "precision": 0, - }, - DefaultParameterName.PRESENCE_PENALTY: { - "label": { - "en_US": "Presence Penalty", - "zh_Hans": "存在惩罚", - 
}, - "type": "float", - "help": { - "en_US": "Applies a penalty to the log-probability of tokens already in the text.", - "zh_Hans": "对文本中已有的标记的对数概率施加惩罚。", - }, - "required": False, - "default": 0.0, - "min": 0.0, - "max": 1.0, - "precision": 2, - }, - DefaultParameterName.FREQUENCY_PENALTY: { - "label": { - "en_US": "Frequency Penalty", - "zh_Hans": "频率惩罚", - }, - "type": "float", - "help": { - "en_US": "Applies a penalty to the log-probability of tokens that appear in the text.", - "zh_Hans": "对文本中出现的标记的对数概率施加惩罚。", - }, - "required": False, - "default": 0.0, - "min": 0.0, - "max": 1.0, - "precision": 2, - }, - DefaultParameterName.MAX_TOKENS: { - "label": { - "en_US": "Max Tokens", - "zh_Hans": "最大 Token 数", - }, - "type": "int", - "help": { - "en_US": "Specifies the upper limit on the length of generated results." - " If the generated results are truncated, you can increase this parameter.", - "zh_Hans": "指定生成结果长度的上限。如果生成结果截断,可以调大该参数。", - }, - "required": False, - "default": 64, - "min": 1, - "max": 2048, - "precision": 0, - }, - DefaultParameterName.RESPONSE_FORMAT: { - "label": { - "en_US": "Response Format", - "zh_Hans": "回复格式", - }, - "type": "string", - "help": { - "en_US": "Set a response format, ensure the output from llm is a valid code block as possible," - " such as JSON, XML, etc.", - "zh_Hans": "设置一个返回格式,确保 llm 的输出尽可能是有效的代码块,如 JSON、XML 等", - }, - "required": False, - "options": ["JSON", "XML"], - }, - DefaultParameterName.JSON_SCHEMA: { - "label": { - "en_US": "JSON Schema", - }, - "type": "text", - "help": { - "en_US": "Set a response json schema will ensure LLM to adhere it.", - "zh_Hans": "设置返回的 json schema,llm 将按照它返回", - }, - "required": False, - }, -} diff --git a/api/dify_graph/model_runtime/entities/llm_entities.py b/api/dify_graph/model_runtime/entities/llm_entities.py deleted file mode 100644 index eec682a2ae..0000000000 --- a/api/dify_graph/model_runtime/entities/llm_entities.py +++ /dev/null @@ -1,219 +0,0 @@ -from __future__ import 
annotations - -from collections.abc import Mapping, Sequence -from decimal import Decimal -from enum import StrEnum -from typing import Any, TypedDict, Union - -from pydantic import BaseModel, Field - -from dify_graph.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage -from dify_graph.model_runtime.entities.model_entities import ModelUsage, PriceInfo - - -class LLMMode(StrEnum): - """ - Enum class for large language model mode. - """ - - COMPLETION = "completion" - CHAT = "chat" - - -class LLMUsageMetadata(TypedDict, total=False): - """ - TypedDict for LLM usage metadata. - All fields are optional. - """ - - prompt_tokens: int - completion_tokens: int - total_tokens: int - prompt_unit_price: Union[float, str] - completion_unit_price: Union[float, str] - total_price: Union[float, str] - currency: str - prompt_price_unit: Union[float, str] - completion_price_unit: Union[float, str] - prompt_price: Union[float, str] - completion_price: Union[float, str] - latency: float - time_to_first_token: float - time_to_generate: float - - -class LLMUsage(ModelUsage): - """ - Model class for llm usage. 
- """ - - prompt_tokens: int - prompt_unit_price: Decimal - prompt_price_unit: Decimal - prompt_price: Decimal - completion_tokens: int - completion_unit_price: Decimal - completion_price_unit: Decimal - completion_price: Decimal - total_tokens: int - total_price: Decimal - currency: str - latency: float - time_to_first_token: float | None = None - time_to_generate: float | None = None - - @classmethod - def empty_usage(cls): - return cls( - prompt_tokens=0, - prompt_unit_price=Decimal("0.0"), - prompt_price_unit=Decimal("0.0"), - prompt_price=Decimal("0.0"), - completion_tokens=0, - completion_unit_price=Decimal("0.0"), - completion_price_unit=Decimal("0.0"), - completion_price=Decimal("0.0"), - total_tokens=0, - total_price=Decimal("0.0"), - currency="USD", - latency=0.0, - time_to_first_token=None, - time_to_generate=None, - ) - - @classmethod - def from_metadata(cls, metadata: LLMUsageMetadata) -> LLMUsage: - """ - Create LLMUsage instance from metadata dictionary with default values. 
- - Args: - metadata: TypedDict containing usage metadata - - Returns: - LLMUsage instance with values from metadata or defaults - """ - prompt_tokens = metadata.get("prompt_tokens", 0) - completion_tokens = metadata.get("completion_tokens", 0) - total_tokens = metadata.get("total_tokens", 0) - - # If total_tokens is not provided but prompt and completion tokens are, - # calculate total_tokens - if total_tokens == 0 and (prompt_tokens > 0 or completion_tokens > 0): - total_tokens = prompt_tokens + completion_tokens - - return cls( - prompt_tokens=prompt_tokens, - completion_tokens=completion_tokens, - total_tokens=total_tokens, - prompt_unit_price=Decimal(str(metadata.get("prompt_unit_price", 0))), - completion_unit_price=Decimal(str(metadata.get("completion_unit_price", 0))), - total_price=Decimal(str(metadata.get("total_price", 0))), - currency=metadata.get("currency", "USD"), - prompt_price_unit=Decimal(str(metadata.get("prompt_price_unit", 0))), - completion_price_unit=Decimal(str(metadata.get("completion_price_unit", 0))), - prompt_price=Decimal(str(metadata.get("prompt_price", 0))), - completion_price=Decimal(str(metadata.get("completion_price", 0))), - latency=metadata.get("latency", 0.0), - time_to_first_token=metadata.get("time_to_first_token"), - time_to_generate=metadata.get("time_to_generate"), - ) - - def plus(self, other: LLMUsage) -> LLMUsage: - """ - Add two LLMUsage instances together. 
- - :param other: Another LLMUsage instance to add - :return: A new LLMUsage instance with summed values - """ - if self.total_tokens == 0: - return other - else: - return LLMUsage( - prompt_tokens=self.prompt_tokens + other.prompt_tokens, - prompt_unit_price=other.prompt_unit_price, - prompt_price_unit=other.prompt_price_unit, - prompt_price=self.prompt_price + other.prompt_price, - completion_tokens=self.completion_tokens + other.completion_tokens, - completion_unit_price=other.completion_unit_price, - completion_price_unit=other.completion_price_unit, - completion_price=self.completion_price + other.completion_price, - total_tokens=self.total_tokens + other.total_tokens, - total_price=self.total_price + other.total_price, - currency=other.currency, - latency=self.latency + other.latency, - time_to_first_token=other.time_to_first_token, - time_to_generate=other.time_to_generate, - ) - - def __add__(self, other: LLMUsage) -> LLMUsage: - """ - Overload the + operator to add two LLMUsage instances. - - :param other: Another LLMUsage instance to add - :return: A new LLMUsage instance with summed values - """ - return self.plus(other) - - -class LLMResult(BaseModel): - """ - Model class for llm result. - """ - - id: str | None = None - model: str - prompt_messages: Sequence[PromptMessage] = Field(default_factory=list) - message: AssistantPromptMessage - usage: LLMUsage - system_fingerprint: str | None = None - reasoning_content: str | None = None - - -class LLMStructuredOutput(BaseModel): - """ - Model class for llm structured output. - """ - - structured_output: Mapping[str, Any] | None = None - - -class LLMResultWithStructuredOutput(LLMResult, LLMStructuredOutput): - """ - Model class for llm result with structured output. - """ - - -class LLMResultChunkDelta(BaseModel): - """ - Model class for llm result chunk delta. 
- """ - - index: int - message: AssistantPromptMessage - usage: LLMUsage | None = None - finish_reason: str | None = None - - -class LLMResultChunk(BaseModel): - """ - Model class for llm result chunk. - """ - - model: str - prompt_messages: Sequence[PromptMessage] = Field(default_factory=list) - system_fingerprint: str | None = None - delta: LLMResultChunkDelta - - -class LLMResultChunkWithStructuredOutput(LLMResultChunk, LLMStructuredOutput): - """ - Model class for llm result chunk with structured output. - """ - - -class NumTokensResult(PriceInfo): - """ - Model class for number of tokens result. - """ - - tokens: int diff --git a/api/dify_graph/model_runtime/entities/message_entities.py b/api/dify_graph/model_runtime/entities/message_entities.py deleted file mode 100644 index 402bfdc606..0000000000 --- a/api/dify_graph/model_runtime/entities/message_entities.py +++ /dev/null @@ -1,279 +0,0 @@ -from __future__ import annotations - -from abc import ABC -from collections.abc import Mapping, Sequence -from enum import StrEnum, auto -from typing import Annotated, Any, Literal, Union - -from pydantic import BaseModel, Field, field_serializer, field_validator - - -class PromptMessageRole(StrEnum): - """ - Enum class for prompt message. - """ - - SYSTEM = auto() - USER = auto() - ASSISTANT = auto() - TOOL = auto() - - @classmethod - def value_of(cls, value: str) -> PromptMessageRole: - """ - Get value of given mode. - - :param value: mode value - :return: mode - """ - for mode in cls: - if mode.value == value: - return mode - raise ValueError(f"invalid prompt message type value {value}") - - -class PromptMessageTool(BaseModel): - """ - Model class for prompt message tool. - """ - - name: str - description: str - parameters: dict - - -class PromptMessageFunction(BaseModel): - """ - Model class for prompt message function. 
- """ - - type: str = "function" - function: PromptMessageTool - - -class PromptMessageContentType(StrEnum): - """ - Enum class for prompt message content type. - """ - - TEXT = auto() - IMAGE = auto() - AUDIO = auto() - VIDEO = auto() - DOCUMENT = auto() - - -class PromptMessageContent(ABC, BaseModel): - """ - Model class for prompt message content. - """ - - type: PromptMessageContentType - - -class TextPromptMessageContent(PromptMessageContent): - """ - Model class for text prompt message content. - """ - - type: Literal[PromptMessageContentType.TEXT] = PromptMessageContentType.TEXT # type: ignore - data: str - - -class MultiModalPromptMessageContent(PromptMessageContent): - """ - Model class for multi-modal prompt message content. - """ - - format: str = Field(default=..., description="the format of multi-modal file") - base64_data: str = Field(default="", description="the base64 data of multi-modal file") - url: str = Field(default="", description="the url of multi-modal file") - mime_type: str = Field(default=..., description="the mime type of multi-modal file") - filename: str = Field(default="", description="the filename of multi-modal file") - - @property - def data(self): - return self.url or f"data:{self.mime_type};base64,{self.base64_data}" - - -class VideoPromptMessageContent(MultiModalPromptMessageContent): - type: Literal[PromptMessageContentType.VIDEO] = PromptMessageContentType.VIDEO # type: ignore - - -class AudioPromptMessageContent(MultiModalPromptMessageContent): - type: Literal[PromptMessageContentType.AUDIO] = PromptMessageContentType.AUDIO # type: ignore - - -class ImagePromptMessageContent(MultiModalPromptMessageContent): - """ - Model class for image prompt message content. 
- """ - - class DETAIL(StrEnum): - LOW = auto() - HIGH = auto() - - type: Literal[PromptMessageContentType.IMAGE] = PromptMessageContentType.IMAGE # type: ignore - detail: DETAIL = DETAIL.LOW - - -class DocumentPromptMessageContent(MultiModalPromptMessageContent): - type: Literal[PromptMessageContentType.DOCUMENT] = PromptMessageContentType.DOCUMENT # type: ignore - - -PromptMessageContentUnionTypes = Annotated[ - Union[ - TextPromptMessageContent, - ImagePromptMessageContent, - DocumentPromptMessageContent, - AudioPromptMessageContent, - VideoPromptMessageContent, - ], - Field(discriminator="type"), -] - - -CONTENT_TYPE_MAPPING: Mapping[PromptMessageContentType, type[PromptMessageContent]] = { - PromptMessageContentType.TEXT: TextPromptMessageContent, - PromptMessageContentType.IMAGE: ImagePromptMessageContent, - PromptMessageContentType.AUDIO: AudioPromptMessageContent, - PromptMessageContentType.VIDEO: VideoPromptMessageContent, - PromptMessageContentType.DOCUMENT: DocumentPromptMessageContent, -} - - -class PromptMessage(ABC, BaseModel): - """ - Model class for prompt message. - """ - - role: PromptMessageRole - content: str | list[PromptMessageContentUnionTypes] | None = None - name: str | None = None - - def is_empty(self) -> bool: - """ - Check if prompt message is empty. - - :return: True if prompt message is empty, False otherwise - """ - return not self.content - - def get_text_content(self) -> str: - """ - Get text content from prompt message. 
- - :return: Text content as string, empty string if no text content - """ - if isinstance(self.content, str): - return self.content - elif isinstance(self.content, list): - text_parts = [] - for item in self.content: - if isinstance(item, TextPromptMessageContent): - text_parts.append(item.data) - return "".join(text_parts) - else: - return "" - - @field_validator("content", mode="before") - @classmethod - def validate_content(cls, v): - if isinstance(v, list): - prompts = [] - for prompt in v: - if isinstance(prompt, PromptMessageContent): - if not isinstance(prompt, TextPromptMessageContent | MultiModalPromptMessageContent): - prompt = CONTENT_TYPE_MAPPING[prompt.type].model_validate(prompt.model_dump()) - elif isinstance(prompt, dict): - prompt = CONTENT_TYPE_MAPPING[prompt["type"]].model_validate(prompt) - else: - raise ValueError(f"invalid prompt message {prompt}") - prompts.append(prompt) - return prompts - return v - - @field_serializer("content") - def serialize_content( - self, content: Union[str, Sequence[PromptMessageContent]] | None - ) -> str | list[dict[str, Any] | PromptMessageContent] | Sequence[PromptMessageContent] | None: - if content is None or isinstance(content, str): - return content - if isinstance(content, list): - return [item.model_dump() if hasattr(item, "model_dump") else item for item in content] - return content - - -class UserPromptMessage(PromptMessage): - """ - Model class for user prompt message. - """ - - role: PromptMessageRole = PromptMessageRole.USER - - -class AssistantPromptMessage(PromptMessage): - """ - Model class for assistant prompt message. - """ - - class ToolCall(BaseModel): - """ - Model class for assistant prompt message tool call. - """ - - class ToolCallFunction(BaseModel): - """ - Model class for assistant prompt message tool call function. 
- """ - - name: str - arguments: str - - id: str - type: str - function: ToolCallFunction - - @field_validator("id", mode="before") - @classmethod - def transform_id_to_str(cls, value) -> str: - if not isinstance(value, str): - return str(value) - else: - return value - - role: PromptMessageRole = PromptMessageRole.ASSISTANT - tool_calls: list[ToolCall] = [] - - def is_empty(self) -> bool: - """ - Check if prompt message is empty. - - :return: True if prompt message is empty, False otherwise - """ - return super().is_empty() and not self.tool_calls - - -class SystemPromptMessage(PromptMessage): - """ - Model class for system prompt message. - """ - - role: PromptMessageRole = PromptMessageRole.SYSTEM - - -class ToolPromptMessage(PromptMessage): - """ - Model class for tool prompt message. - """ - - role: PromptMessageRole = PromptMessageRole.TOOL - tool_call_id: str - - def is_empty(self) -> bool: - """ - Check if prompt message is empty. - - :return: True if prompt message is empty, False otherwise - """ - return super().is_empty() and not self.tool_call_id diff --git a/api/dify_graph/model_runtime/entities/model_entities.py b/api/dify_graph/model_runtime/entities/model_entities.py deleted file mode 100644 index fbcde6740a..0000000000 --- a/api/dify_graph/model_runtime/entities/model_entities.py +++ /dev/null @@ -1,242 +0,0 @@ -from __future__ import annotations - -from decimal import Decimal -from enum import StrEnum, auto -from typing import Any - -from pydantic import BaseModel, ConfigDict, model_validator - -from dify_graph.model_runtime.entities.common_entities import I18nObject - - -class ModelType(StrEnum): - """ - Enum class for model type. - """ - - LLM = auto() - TEXT_EMBEDDING = "text-embedding" - RERANK = auto() - SPEECH2TEXT = auto() - MODERATION = auto() - TTS = auto() - - @classmethod - def value_of(cls, origin_model_type: str) -> ModelType: - """ - Get model type from origin model type. 
- - :return: model type - """ - if origin_model_type in {"text-generation", cls.LLM}: - return cls.LLM - elif origin_model_type in {"embeddings", cls.TEXT_EMBEDDING}: - return cls.TEXT_EMBEDDING - elif origin_model_type in {"reranking", cls.RERANK}: - return cls.RERANK - elif origin_model_type in {"speech2text", cls.SPEECH2TEXT}: - return cls.SPEECH2TEXT - elif origin_model_type in {"tts", cls.TTS}: - return cls.TTS - elif origin_model_type == cls.MODERATION: - return cls.MODERATION - else: - raise ValueError(f"invalid origin model type {origin_model_type}") - - def to_origin_model_type(self) -> str: - """ - Get origin model type from model type. - - :return: origin model type - """ - if self == self.LLM: - return "text-generation" - elif self == self.TEXT_EMBEDDING: - return "embeddings" - elif self == self.RERANK: - return "reranking" - elif self == self.SPEECH2TEXT: - return "speech2text" - elif self == self.TTS: - return "tts" - elif self == self.MODERATION: - return "moderation" - else: - raise ValueError(f"invalid model type {self}") - - -class FetchFrom(StrEnum): - """ - Enum class for fetch from. - """ - - PREDEFINED_MODEL = "predefined-model" - CUSTOMIZABLE_MODEL = "customizable-model" - - -class ModelFeature(StrEnum): - """ - Enum class for llm feature. - """ - - TOOL_CALL = "tool-call" - MULTI_TOOL_CALL = "multi-tool-call" - AGENT_THOUGHT = "agent-thought" - VISION = auto() - STREAM_TOOL_CALL = "stream-tool-call" - DOCUMENT = auto() - VIDEO = auto() - AUDIO = auto() - STRUCTURED_OUTPUT = "structured-output" - - -class DefaultParameterName(StrEnum): - """ - Enum class for parameter template variable. - """ - - TEMPERATURE = auto() - TOP_P = auto() - TOP_K = auto() - PRESENCE_PENALTY = auto() - FREQUENCY_PENALTY = auto() - MAX_TOKENS = auto() - RESPONSE_FORMAT = auto() - JSON_SCHEMA = auto() - - @classmethod - def value_of(cls, value: Any) -> DefaultParameterName: - """ - Get parameter name from value. 
- - :param value: parameter value - :return: parameter name - """ - for name in cls: - if name.value == value: - return name - raise ValueError(f"invalid parameter name {value}") - - -class ParameterType(StrEnum): - """ - Enum class for parameter type. - """ - - FLOAT = auto() - INT = auto() - STRING = auto() - BOOLEAN = auto() - TEXT = auto() - - -class ModelPropertyKey(StrEnum): - """ - Enum class for model property key. - """ - - MODE = auto() - CONTEXT_SIZE = auto() - MAX_CHUNKS = auto() - FILE_UPLOAD_LIMIT = auto() - SUPPORTED_FILE_EXTENSIONS = auto() - MAX_CHARACTERS_PER_CHUNK = auto() - DEFAULT_VOICE = auto() - VOICES = auto() - WORD_LIMIT = auto() - AUDIO_TYPE = auto() - MAX_WORKERS = auto() - - -class ProviderModel(BaseModel): - """ - Model class for provider model. - """ - - model: str - label: I18nObject - model_type: ModelType - features: list[ModelFeature] | None = None - fetch_from: FetchFrom - model_properties: dict[ModelPropertyKey, Any] - deprecated: bool = False - model_config = ConfigDict(protected_namespaces=()) - - @property - def support_structure_output(self) -> bool: - return self.features is not None and ModelFeature.STRUCTURED_OUTPUT in self.features - - -class ParameterRule(BaseModel): - """ - Model class for parameter rule. - """ - - name: str - use_template: str | None = None - label: I18nObject - type: ParameterType - help: I18nObject | None = None - required: bool = False - default: Any | None = None - min: float | None = None - max: float | None = None - precision: int | None = None - options: list[str] = [] - - -class PriceConfig(BaseModel): - """ - Model class for pricing info. - """ - - input: Decimal - output: Decimal | None = None - unit: Decimal - currency: str - - -class AIModelEntity(ProviderModel): - """ - Model class for AI model. 
- """ - - parameter_rules: list[ParameterRule] = [] - pricing: PriceConfig | None = None - - @model_validator(mode="after") - def validate_model(self): - supported_schema_keys = ["json_schema"] - schema_key = next((rule.name for rule in self.parameter_rules if rule.name in supported_schema_keys), None) - if not schema_key: - return self - if self.features is None: - self.features = [ModelFeature.STRUCTURED_OUTPUT] - else: - if ModelFeature.STRUCTURED_OUTPUT not in self.features: - self.features.append(ModelFeature.STRUCTURED_OUTPUT) - return self - - -class ModelUsage(BaseModel): - pass - - -class PriceType(StrEnum): - """ - Enum class for price type. - """ - - INPUT = auto() - OUTPUT = auto() - - -class PriceInfo(BaseModel): - """ - Model class for price info. - """ - - unit_price: Decimal - unit: Decimal - total_amount: Decimal - currency: str diff --git a/api/dify_graph/model_runtime/entities/provider_entities.py b/api/dify_graph/model_runtime/entities/provider_entities.py deleted file mode 100644 index 97a99ea7ce..0000000000 --- a/api/dify_graph/model_runtime/entities/provider_entities.py +++ /dev/null @@ -1,169 +0,0 @@ -from collections.abc import Sequence -from enum import StrEnum, auto - -from pydantic import BaseModel, ConfigDict, Field, field_validator, model_validator - -from dify_graph.model_runtime.entities.common_entities import I18nObject -from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelType - - -class ConfigurateMethod(StrEnum): - """ - Enum class for configurate method of provider model. - """ - - PREDEFINED_MODEL = "predefined-model" - CUSTOMIZABLE_MODEL = "customizable-model" - - -class FormType(StrEnum): - """ - Enum class for form type. - """ - - TEXT_INPUT = "text-input" - SECRET_INPUT = "secret-input" - SELECT = auto() - RADIO = auto() - SWITCH = auto() - - -class FormShowOnObject(BaseModel): - """ - Model class for form show on. 
- """ - - variable: str - value: str - - -class FormOption(BaseModel): - """ - Model class for form option. - """ - - label: I18nObject - value: str - show_on: list[FormShowOnObject] = [] - - @model_validator(mode="after") - def _(self): - if not self.label: - self.label = I18nObject(en_US=self.value) - return self - - -class CredentialFormSchema(BaseModel): - """ - Model class for credential form schema. - """ - - variable: str - label: I18nObject - type: FormType - required: bool = True - default: str | None = None - options: list[FormOption] | None = None - placeholder: I18nObject | None = None - max_length: int = 0 - show_on: list[FormShowOnObject] = [] - - -class ProviderCredentialSchema(BaseModel): - """ - Model class for provider credential schema. - """ - - credential_form_schemas: list[CredentialFormSchema] - - -class FieldModelSchema(BaseModel): - label: I18nObject - placeholder: I18nObject | None = None - - -class ModelCredentialSchema(BaseModel): - """ - Model class for model credential schema. - """ - - model: FieldModelSchema - credential_form_schemas: list[CredentialFormSchema] - - -class SimpleProviderEntity(BaseModel): - """ - Simple model class for provider. - """ - - provider: str - label: I18nObject - icon_small: I18nObject | None = None - icon_small_dark: I18nObject | None = None - supported_model_types: Sequence[ModelType] - models: list[AIModelEntity] = [] - - -class ProviderHelpEntity(BaseModel): - """ - Model class for provider help. - """ - - title: I18nObject - url: I18nObject - - -class ProviderEntity(BaseModel): - """ - Model class for provider. 
- """ - - provider: str - label: I18nObject - description: I18nObject | None = None - icon_small: I18nObject | None = None - icon_small_dark: I18nObject | None = None - background: str | None = None - help: ProviderHelpEntity | None = None - supported_model_types: Sequence[ModelType] - configurate_methods: list[ConfigurateMethod] - models: list[AIModelEntity] = Field(default_factory=list) - provider_credential_schema: ProviderCredentialSchema | None = None - model_credential_schema: ModelCredentialSchema | None = None - - # pydantic configs - model_config = ConfigDict(protected_namespaces=()) - - # position from plugin _position.yaml - position: dict[str, list[str]] | None = {} - - @field_validator("models", mode="before") - @classmethod - def validate_models(cls, v): - # returns EmptyList if v is empty - if not v: - return [] - return v - - def to_simple_provider(self) -> SimpleProviderEntity: - """ - Convert to simple provider. - - :return: simple provider - """ - return SimpleProviderEntity( - provider=self.provider, - label=self.label, - icon_small=self.icon_small, - supported_model_types=self.supported_model_types, - models=self.models, - ) - - -class ProviderConfig(BaseModel): - """ - Model class for provider config. - """ - - provider: str - credentials: dict diff --git a/api/dify_graph/model_runtime/entities/rerank_entities.py b/api/dify_graph/model_runtime/entities/rerank_entities.py deleted file mode 100644 index 99709e1bcd..0000000000 --- a/api/dify_graph/model_runtime/entities/rerank_entities.py +++ /dev/null @@ -1,20 +0,0 @@ -from pydantic import BaseModel - - -class RerankDocument(BaseModel): - """ - Model class for rerank document. - """ - - index: int - text: str - score: float - - -class RerankResult(BaseModel): - """ - Model class for rerank result. 
- """ - - model: str - docs: list[RerankDocument] diff --git a/api/dify_graph/model_runtime/entities/text_embedding_entities.py b/api/dify_graph/model_runtime/entities/text_embedding_entities.py deleted file mode 100644 index a0210c169d..0000000000 --- a/api/dify_graph/model_runtime/entities/text_embedding_entities.py +++ /dev/null @@ -1,39 +0,0 @@ -from decimal import Decimal - -from pydantic import BaseModel - -from dify_graph.model_runtime.entities.model_entities import ModelUsage - - -class EmbeddingUsage(ModelUsage): - """ - Model class for embedding usage. - """ - - tokens: int - total_tokens: int - unit_price: Decimal - price_unit: Decimal - total_price: Decimal - currency: str - latency: float - - -class EmbeddingResult(BaseModel): - """ - Model class for text embedding result. - """ - - model: str - embeddings: list[list[float]] - usage: EmbeddingUsage - - -class FileEmbeddingResult(BaseModel): - """ - Model class for file embedding result. - """ - - model: str - embeddings: list[list[float]] - usage: EmbeddingUsage diff --git a/api/dify_graph/model_runtime/errors/invoke.py b/api/dify_graph/model_runtime/errors/invoke.py deleted file mode 100644 index 1a57078b98..0000000000 --- a/api/dify_graph/model_runtime/errors/invoke.py +++ /dev/null @@ -1,41 +0,0 @@ -class InvokeError(ValueError): - """Base class for all LLM exceptions.""" - - description: str | None = None - - def __init__(self, description: str | None = None): - if description is not None: - self.description = description - - def __str__(self): - return self.description or self.__class__.__name__ - - -class InvokeConnectionError(InvokeError): - """Raised when the Invoke returns connection error.""" - - description = "Connection Error" - - -class InvokeServerUnavailableError(InvokeError): - """Raised when the Invoke returns server unavailable error.""" - - description = "Server Unavailable Error" - - -class InvokeRateLimitError(InvokeError): - """Raised when the Invoke returns rate limit error.""" - 
- description = "Rate Limit Error" - - -class InvokeAuthorizationError(InvokeError): - """Raised when the Invoke returns authorization error.""" - - description = "Incorrect model credentials provided, please check and try again. " - - -class InvokeBadRequestError(InvokeError): - """Raised when the Invoke returns bad request.""" - - description = "Bad Request Error" diff --git a/api/dify_graph/model_runtime/errors/validate.py b/api/dify_graph/model_runtime/errors/validate.py deleted file mode 100644 index 16bebcc67d..0000000000 --- a/api/dify_graph/model_runtime/errors/validate.py +++ /dev/null @@ -1,6 +0,0 @@ -class CredentialsValidateFailedError(ValueError): - """ - Credentials validate failed error - """ - - pass diff --git a/api/dify_graph/model_runtime/memory/__init__.py b/api/dify_graph/model_runtime/memory/__init__.py deleted file mode 100644 index 2d954486c3..0000000000 --- a/api/dify_graph/model_runtime/memory/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .prompt_message_memory import DEFAULT_MEMORY_MAX_TOKEN_LIMIT, PromptMessageMemory - -__all__ = ["DEFAULT_MEMORY_MAX_TOKEN_LIMIT", "PromptMessageMemory"] diff --git a/api/dify_graph/model_runtime/memory/prompt_message_memory.py b/api/dify_graph/model_runtime/memory/prompt_message_memory.py deleted file mode 100644 index a76a7faf71..0000000000 --- a/api/dify_graph/model_runtime/memory/prompt_message_memory.py +++ /dev/null @@ -1,18 +0,0 @@ -from __future__ import annotations - -from collections.abc import Sequence -from typing import Protocol - -from dify_graph.model_runtime.entities import PromptMessage - -DEFAULT_MEMORY_MAX_TOKEN_LIMIT = 2000 - - -class PromptMessageMemory(Protocol): - """Port for loading memory as prompt messages.""" - - def get_history_prompt_messages( - self, max_token_limit: int = DEFAULT_MEMORY_MAX_TOKEN_LIMIT, message_limit: int | None = None - ) -> Sequence[PromptMessage]: - """Return historical prompt messages constrained by token/message limits.""" - ... 
diff --git a/api/dify_graph/model_runtime/model_providers/__base/ai_model.py b/api/dify_graph/model_runtime/model_providers/__base/ai_model.py deleted file mode 100644 index ac7ae9925b..0000000000 --- a/api/dify_graph/model_runtime/model_providers/__base/ai_model.py +++ /dev/null @@ -1,286 +0,0 @@ -import decimal -import hashlib -import logging - -from pydantic import BaseModel, ConfigDict, Field, ValidationError -from redis import RedisError - -from configs import dify_config -from core.plugin.entities.plugin_daemon import PluginModelProviderEntity -from dify_graph.model_runtime.entities.common_entities import I18nObject -from dify_graph.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE -from dify_graph.model_runtime.entities.model_entities import ( - AIModelEntity, - DefaultParameterName, - ModelType, - PriceConfig, - PriceInfo, - PriceType, -) -from dify_graph.model_runtime.errors.invoke import ( - InvokeAuthorizationError, - InvokeBadRequestError, - InvokeConnectionError, - InvokeError, - InvokeRateLimitError, - InvokeServerUnavailableError, -) -from extensions.ext_redis import redis_client - -logger = logging.getLogger(__name__) - - -class AIModel(BaseModel): - """ - Base class for all models. 
- """ - - tenant_id: str = Field(description="Tenant ID") - model_type: ModelType = Field(description="Model type") - plugin_id: str = Field(description="Plugin ID") - provider_name: str = Field(description="Provider") - plugin_model_provider: PluginModelProviderEntity = Field(description="Plugin model provider") - started_at: float = Field(description="Invoke start time", default=0) - - # pydantic configs - model_config = ConfigDict(protected_namespaces=()) - - @property - def _invoke_error_mapping(self) -> dict[type[Exception], list[type[Exception]]]: - """ - Map model invoke error to unified error - The key is the error type thrown to the caller - The value is the error type thrown by the model, - which needs to be converted into a unified error type for the caller. - - :return: Invoke error mapping - """ - from core.plugin.entities.plugin_daemon import PluginDaemonInnerError - - return { - InvokeConnectionError: [InvokeConnectionError], - InvokeServerUnavailableError: [InvokeServerUnavailableError], - InvokeRateLimitError: [InvokeRateLimitError], - InvokeAuthorizationError: [InvokeAuthorizationError], - InvokeBadRequestError: [InvokeBadRequestError], - PluginDaemonInnerError: [PluginDaemonInnerError], - ValueError: [ValueError], - } - - def _transform_invoke_error(self, error: Exception) -> Exception: - """ - Transform invoke error to unified error - - :param error: model invoke error - :return: unified error - """ - for invoke_error, model_errors in self._invoke_error_mapping.items(): - if isinstance(error, tuple(model_errors)): - if invoke_error == InvokeAuthorizationError: - return InvokeAuthorizationError( - description=( - f"[{self.provider_name}] Incorrect model credentials provided, please check and try again." 
- ) - ) - elif isinstance(invoke_error, InvokeError): - return InvokeError(description=f"[{self.provider_name}] {invoke_error.description}, {str(error)}") - else: - return error - - return InvokeError(description=f"[{self.provider_name}] Error: {str(error)}") - - def get_price(self, model: str, credentials: dict, price_type: PriceType, tokens: int) -> PriceInfo: - """ - Get price for given model and tokens - - :param model: model name - :param credentials: model credentials - :param price_type: price type - :param tokens: number of tokens - :return: price info - """ - # get model schema - model_schema = self.get_model_schema(model, credentials) - - # get price info from predefined model schema - price_config: PriceConfig | None = None - if model_schema and model_schema.pricing: - price_config = model_schema.pricing - - # get unit price - unit_price = None - if price_config: - if price_type == PriceType.INPUT: - unit_price = price_config.input - elif price_type == PriceType.OUTPUT and price_config.output is not None: - unit_price = price_config.output - - if unit_price is None: - return PriceInfo( - unit_price=decimal.Decimal("0.0"), - unit=decimal.Decimal("0.0"), - total_amount=decimal.Decimal("0.0"), - currency="USD", - ) - - # calculate total amount - if not price_config: - raise ValueError(f"Price config not found for model {model}") - total_amount = tokens * unit_price * price_config.unit - total_amount = total_amount.quantize(decimal.Decimal("0.0000001"), rounding=decimal.ROUND_HALF_UP) - - return PriceInfo( - unit_price=unit_price, - unit=price_config.unit, - total_amount=total_amount, - currency=price_config.currency, - ) - - def get_model_schema(self, model: str, credentials: dict | None = None) -> AIModelEntity | None: - """ - Get model schema by model name and credentials - - :param model: model name - :param credentials: model credentials - :return: model schema - """ - from core.plugin.impl.model import PluginModelClient - - plugin_model_manager = 
PluginModelClient() - cache_key = f"{self.tenant_id}:{self.plugin_id}:{self.provider_name}:{self.model_type.value}:{model}" - sorted_credentials = sorted(credentials.items()) if credentials else [] - cache_key += ":".join([hashlib.md5(f"{k}:{v}".encode()).hexdigest() for k, v in sorted_credentials]) - - cached_schema_json = None - try: - cached_schema_json = redis_client.get(cache_key) - except (RedisError, RuntimeError) as exc: - logger.warning( - "Failed to read plugin model schema cache for model %s: %s", - model, - str(exc), - exc_info=True, - ) - if cached_schema_json: - try: - return AIModelEntity.model_validate_json(cached_schema_json) - except ValidationError: - logger.warning( - "Failed to validate cached plugin model schema for model %s", - model, - exc_info=True, - ) - try: - redis_client.delete(cache_key) - except (RedisError, RuntimeError) as exc: - logger.warning( - "Failed to delete invalid plugin model schema cache for model %s: %s", - model, - str(exc), - exc_info=True, - ) - - schema = plugin_model_manager.get_model_schema( - tenant_id=self.tenant_id, - user_id="unknown", - plugin_id=self.plugin_id, - provider=self.provider_name, - model_type=self.model_type.value, - model=model, - credentials=credentials or {}, - ) - - if schema: - try: - redis_client.setex(cache_key, dify_config.PLUGIN_MODEL_SCHEMA_CACHE_TTL, schema.model_dump_json()) - except (RedisError, RuntimeError) as exc: - logger.warning( - "Failed to write plugin model schema cache for model %s: %s", - model, - str(exc), - exc_info=True, - ) - - return schema - - def get_customizable_model_schema_from_credentials(self, model: str, credentials: dict) -> AIModelEntity | None: - """ - Get customizable model schema from credentials - - :param model: model name - :param credentials: model credentials - :return: model schema - """ - - # get customizable model schema - schema = self.get_customizable_model_schema(model, credentials) - if not schema: - return None - - # fill in the template - 
new_parameter_rules = [] - for parameter_rule in schema.parameter_rules: - if parameter_rule.use_template: - try: - default_parameter_name = DefaultParameterName.value_of(parameter_rule.use_template) - default_parameter_rule = self._get_default_parameter_rule_variable_map(default_parameter_name) - if not parameter_rule.max and "max" in default_parameter_rule: - parameter_rule.max = default_parameter_rule["max"] - if not parameter_rule.min and "min" in default_parameter_rule: - parameter_rule.min = default_parameter_rule["min"] - if not parameter_rule.default and "default" in default_parameter_rule: - parameter_rule.default = default_parameter_rule["default"] - if not parameter_rule.precision and "precision" in default_parameter_rule: - parameter_rule.precision = default_parameter_rule["precision"] - if not parameter_rule.required and "required" in default_parameter_rule: - parameter_rule.required = default_parameter_rule["required"] - if not parameter_rule.help and "help" in default_parameter_rule: - parameter_rule.help = I18nObject( - en_US=default_parameter_rule["help"]["en_US"], - ) - if ( - parameter_rule.help - and not parameter_rule.help.en_US - and ("help" in default_parameter_rule and "en_US" in default_parameter_rule["help"]) - ): - parameter_rule.help.en_US = default_parameter_rule["help"]["en_US"] - if ( - parameter_rule.help - and not parameter_rule.help.zh_Hans - and ("help" in default_parameter_rule and "zh_Hans" in default_parameter_rule["help"]) - ): - parameter_rule.help.zh_Hans = default_parameter_rule["help"].get( - "zh_Hans", default_parameter_rule["help"]["en_US"] - ) - except ValueError: - pass - - new_parameter_rules.append(parameter_rule) - - schema.parameter_rules = new_parameter_rules - - return schema - - def get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None: - """ - Get customizable model schema - - :param model: model name - :param credentials: model credentials - :return: model schema - """ - 
return None - - def _get_default_parameter_rule_variable_map(self, name: DefaultParameterName): - """ - Get default parameter rule for given name - - :param name: parameter name - :return: parameter rule - """ - default_parameter_rule = PARAMETER_RULE_TEMPLATE.get(name) - - if not default_parameter_rule: - raise Exception(f"Invalid model parameter rule name {name}") - - return default_parameter_rule diff --git a/api/dify_graph/model_runtime/model_providers/__base/large_language_model.py b/api/dify_graph/model_runtime/model_providers/__base/large_language_model.py deleted file mode 100644 index bf864ca227..0000000000 --- a/api/dify_graph/model_runtime/model_providers/__base/large_language_model.py +++ /dev/null @@ -1,668 +0,0 @@ -import logging -import time -import uuid -from collections.abc import Callable, Generator, Iterator, Sequence -from typing import Union - -from pydantic import ConfigDict - -from configs import dify_config -from dify_graph.model_runtime.callbacks.base_callback import Callback -from dify_graph.model_runtime.callbacks.logging_callback import LoggingCallback -from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMUsage -from dify_graph.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - PromptMessage, - PromptMessageContentUnionTypes, - PromptMessageTool, - TextPromptMessageContent, -) -from dify_graph.model_runtime.entities.model_entities import ( - ModelType, - PriceType, -) -from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel - -logger = logging.getLogger(__name__) - - -def _gen_tool_call_id() -> str: - return f"chatcmpl-tool-{str(uuid.uuid4().hex)}" - - -def _run_callbacks(callbacks: Sequence[Callback] | None, *, event: str, invoke: Callable[[Callback], None]) -> None: - if not callbacks: - return - - for callback in callbacks: - try: - invoke(callback) - except Exception as e: - if callback.raise_error: - raise - logger.warning("Callback %s %s failed with 
error %s", callback.__class__.__name__, event, e) - - -def _get_or_create_tool_call( - existing_tools_calls: list[AssistantPromptMessage.ToolCall], - tool_call_id: str, -) -> AssistantPromptMessage.ToolCall: - """ - Get or create a tool call by ID. - - If `tool_call_id` is empty, returns the most recently created tool call. - """ - if not tool_call_id: - if not existing_tools_calls: - raise ValueError("tool_call_id is empty but no existing tool call is available to apply the delta") - return existing_tools_calls[-1] - - tool_call = next((tool_call for tool_call in existing_tools_calls if tool_call.id == tool_call_id), None) - if tool_call is None: - tool_call = AssistantPromptMessage.ToolCall( - id=tool_call_id, - type="function", - function=AssistantPromptMessage.ToolCall.ToolCallFunction(name="", arguments=""), - ) - existing_tools_calls.append(tool_call) - - return tool_call - - -def _merge_tool_call_delta( - tool_call: AssistantPromptMessage.ToolCall, - delta: AssistantPromptMessage.ToolCall, -) -> None: - if delta.id: - tool_call.id = delta.id - if delta.type: - tool_call.type = delta.type - if delta.function.name: - tool_call.function.name = delta.function.name - if delta.function.arguments: - tool_call.function.arguments += delta.function.arguments - - -def _build_llm_result_from_chunks( - model: str, - prompt_messages: Sequence[PromptMessage], - chunks: Iterator[LLMResultChunk], -) -> LLMResult: - """ - Build a single `LLMResult` by accumulating all returned chunks. - - Some models only support streaming output (e.g. Qwen3 open-source edition) - and the plugin side may still implement the response via a chunked stream, - so all chunks must be consumed and concatenated into a single ``LLMResult``. - - The ``usage`` is taken from the last chunk that carries it, which is the - typical convention for streaming responses (the final chunk contains the - aggregated token counts). 
- """ - content = "" - content_list: list[PromptMessageContentUnionTypes] = [] - usage = LLMUsage.empty_usage() - system_fingerprint: str | None = None - tools_calls: list[AssistantPromptMessage.ToolCall] = [] - - try: - for chunk in chunks: - if isinstance(chunk.delta.message.content, str): - content += chunk.delta.message.content - elif isinstance(chunk.delta.message.content, list): - content_list.extend(chunk.delta.message.content) - - if chunk.delta.message.tool_calls: - _increase_tool_call(chunk.delta.message.tool_calls, tools_calls) - - if chunk.delta.usage: - usage = chunk.delta.usage - if chunk.system_fingerprint: - system_fingerprint = chunk.system_fingerprint - except Exception: - logger.exception("Error while consuming non-stream plugin chunk iterator.") - raise - finally: - # Drain any remaining chunks to release underlying streaming resources (e.g. HTTP connections). - close = getattr(chunks, "close", None) - if callable(close): - close() - - return LLMResult( - model=model, - prompt_messages=prompt_messages, - message=AssistantPromptMessage( - content=content or content_list, - tool_calls=tools_calls, - ), - usage=usage, - system_fingerprint=system_fingerprint, - ) - - -def _invoke_llm_via_plugin( - *, - tenant_id: str, - user_id: str, - plugin_id: str, - provider: str, - model: str, - credentials: dict, - model_parameters: dict, - prompt_messages: Sequence[PromptMessage], - tools: list[PromptMessageTool] | None, - stop: Sequence[str] | None, - stream: bool, -) -> Union[LLMResult, Generator[LLMResultChunk, None, None]]: - from core.plugin.impl.model import PluginModelClient - - plugin_model_manager = PluginModelClient() - return plugin_model_manager.invoke_llm( - tenant_id=tenant_id, - user_id=user_id, - plugin_id=plugin_id, - provider=provider, - model=model, - credentials=credentials, - model_parameters=model_parameters, - prompt_messages=list(prompt_messages), - tools=tools, - stop=list(stop) if stop else None, - stream=stream, - ) - - -def 
_normalize_non_stream_plugin_result( - model: str, - prompt_messages: Sequence[PromptMessage], - result: Union[LLMResult, Iterator[LLMResultChunk]], -) -> LLMResult: - if isinstance(result, LLMResult): - return result - return _build_llm_result_from_chunks(model=model, prompt_messages=prompt_messages, chunks=result) - - -def _increase_tool_call( - new_tool_calls: list[AssistantPromptMessage.ToolCall], existing_tools_calls: list[AssistantPromptMessage.ToolCall] -): - """ - Merge incremental tool call updates into existing tool calls. - - :param new_tool_calls: List of new tool call deltas to be merged. - :param existing_tools_calls: List of existing tool calls to be modified IN-PLACE. - """ - - for new_tool_call in new_tool_calls: - # generate ID for tool calls with function name but no ID to track them - if new_tool_call.function.name and not new_tool_call.id: - new_tool_call.id = _gen_tool_call_id() - - tool_call = _get_or_create_tool_call(existing_tools_calls, new_tool_call.id) - _merge_tool_call_delta(tool_call, new_tool_call) - - -class LargeLanguageModel(AIModel): - """ - Model class for large language model. 
- """ - - model_type: ModelType = ModelType.LLM - - # pydantic configs - model_config = ConfigDict(protected_namespaces=()) - - def invoke( - self, - model: str, - credentials: dict, - prompt_messages: list[PromptMessage], - model_parameters: dict | None = None, - tools: list[PromptMessageTool] | None = None, - stop: list[str] | None = None, - stream: bool = True, - user: str | None = None, - callbacks: list[Callback] | None = None, - ) -> Union[LLMResult, Generator[LLMResultChunk, None, None]]: - """ - Invoke large language model - - :param model: model name - :param credentials: model credentials - :param prompt_messages: prompt messages - :param model_parameters: model parameters - :param tools: tools for tool calling - :param stop: stop words - :param stream: is stream response - :param user: unique user id - :param callbacks: callbacks - :return: full response or stream response chunk generator result - """ - # validate and filter model parameters - if model_parameters is None: - model_parameters = {} - - self.started_at = time.perf_counter() - - callbacks = callbacks or [] - - if dify_config.DEBUG: - callbacks.append(LoggingCallback()) - - # trigger before invoke callbacks - self._trigger_before_invoke_callbacks( - model=model, - credentials=credentials, - prompt_messages=prompt_messages, - model_parameters=model_parameters, - tools=tools, - stop=stop, - stream=stream, - user=user, - callbacks=callbacks, - ) - - result: Union[LLMResult, Generator[LLMResultChunk, None, None]] - - try: - result = _invoke_llm_via_plugin( - tenant_id=self.tenant_id, - user_id=user or "unknown", - plugin_id=self.plugin_id, - provider=self.provider_name, - model=model, - credentials=credentials, - model_parameters=model_parameters, - prompt_messages=prompt_messages, - tools=tools, - stop=stop, - stream=stream, - ) - - if not stream: - result = _normalize_non_stream_plugin_result( - model=model, prompt_messages=prompt_messages, result=result - ) - except Exception as e: - 
self._trigger_invoke_error_callbacks( - model=model, - ex=e, - credentials=credentials, - prompt_messages=prompt_messages, - model_parameters=model_parameters, - tools=tools, - stop=stop, - stream=stream, - user=user, - callbacks=callbacks, - ) - - # TODO - raise self._transform_invoke_error(e) - - if stream and not isinstance(result, LLMResult): - return self._invoke_result_generator( - model=model, - result=result, - credentials=credentials, - prompt_messages=prompt_messages, - model_parameters=model_parameters, - tools=tools, - stop=stop, - stream=stream, - user=user, - callbacks=callbacks, - ) - elif isinstance(result, LLMResult): - self._trigger_after_invoke_callbacks( - model=model, - result=result, - credentials=credentials, - prompt_messages=prompt_messages, - model_parameters=model_parameters, - tools=tools, - stop=stop, - stream=stream, - user=user, - callbacks=callbacks, - ) - # Following https://github.com/langgenius/dify/issues/17799, - # we removed the prompt_messages from the chunk on the plugin daemon side. - # To ensure compatibility, we add the prompt_messages back here. 
- result.prompt_messages = prompt_messages - return result - raise NotImplementedError("unsupported invoke result type", type(result)) - - def _invoke_result_generator( - self, - model: str, - result: Generator[LLMResultChunk, None, None], - credentials: dict, - prompt_messages: Sequence[PromptMessage], - model_parameters: dict, - tools: list[PromptMessageTool] | None = None, - stop: Sequence[str] | None = None, - stream: bool = True, - user: str | None = None, - callbacks: list[Callback] | None = None, - ) -> Generator[LLMResultChunk, None, None]: - """ - Invoke result generator - - :param result: result generator - :return: result generator - """ - callbacks = callbacks or [] - message_content: list[PromptMessageContentUnionTypes] = [] - usage = None - system_fingerprint = None - real_model = model - - def _update_message_content(content: str | list[PromptMessageContentUnionTypes] | None): - if not content: - return - if isinstance(content, list): - message_content.extend(content) - return - if isinstance(content, str): - message_content.append(TextPromptMessageContent(data=content)) - return - - try: - for chunk in result: - # Following https://github.com/langgenius/dify/issues/17799, - # we removed the prompt_messages from the chunk on the plugin daemon side. - # To ensure compatibility, we add the prompt_messages back here. 
- chunk.prompt_messages = prompt_messages - yield chunk - - self._trigger_new_chunk_callbacks( - chunk=chunk, - model=model, - credentials=credentials, - prompt_messages=prompt_messages, - model_parameters=model_parameters, - tools=tools, - stop=stop, - stream=stream, - user=user, - callbacks=callbacks, - ) - - _update_message_content(chunk.delta.message.content) - - real_model = chunk.model - if chunk.delta.usage: - usage = chunk.delta.usage - - if chunk.system_fingerprint: - system_fingerprint = chunk.system_fingerprint - except Exception as e: - raise self._transform_invoke_error(e) - - assistant_message = AssistantPromptMessage(content=message_content) - self._trigger_after_invoke_callbacks( - model=model, - result=LLMResult( - model=real_model, - prompt_messages=prompt_messages, - message=assistant_message, - usage=usage or LLMUsage.empty_usage(), - system_fingerprint=system_fingerprint, - ), - credentials=credentials, - prompt_messages=prompt_messages, - model_parameters=model_parameters, - tools=tools, - stop=stop, - stream=stream, - user=user, - callbacks=callbacks, - ) - - def get_num_tokens( - self, - model: str, - credentials: dict, - prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool] | None = None, - ) -> int: - """ - Get number of tokens for given prompt messages - - :param model: model name - :param credentials: model credentials - :param prompt_messages: prompt messages - :param tools: tools for tool calling - :return: - """ - if dify_config.PLUGIN_BASED_TOKEN_COUNTING_ENABLED: - from core.plugin.impl.model import PluginModelClient - - plugin_model_manager = PluginModelClient() - return plugin_model_manager.get_llm_num_tokens( - tenant_id=self.tenant_id, - user_id="unknown", - plugin_id=self.plugin_id, - provider=self.provider_name, - model_type=self.model_type.value, - model=model, - credentials=credentials, - prompt_messages=prompt_messages, - tools=tools, - ) - return 0 - - def calc_response_usage( - self, model: str, 
credentials: dict, prompt_tokens: int, completion_tokens: int - ) -> LLMUsage: - """ - Calculate response usage - - :param model: model name - :param credentials: model credentials - :param prompt_tokens: prompt tokens - :param completion_tokens: completion tokens - :return: usage - """ - # get prompt price info - prompt_price_info = self.get_price( - model=model, - credentials=credentials, - price_type=PriceType.INPUT, - tokens=prompt_tokens, - ) - - # get completion price info - completion_price_info = self.get_price( - model=model, credentials=credentials, price_type=PriceType.OUTPUT, tokens=completion_tokens - ) - - # transform usage - usage = LLMUsage( - prompt_tokens=prompt_tokens, - prompt_unit_price=prompt_price_info.unit_price, - prompt_price_unit=prompt_price_info.unit, - prompt_price=prompt_price_info.total_amount, - completion_tokens=completion_tokens, - completion_unit_price=completion_price_info.unit_price, - completion_price_unit=completion_price_info.unit, - completion_price=completion_price_info.total_amount, - total_tokens=prompt_tokens + completion_tokens, - total_price=prompt_price_info.total_amount + completion_price_info.total_amount, - currency=prompt_price_info.currency, - latency=time.perf_counter() - self.started_at, - ) - - return usage - - def _trigger_before_invoke_callbacks( - self, - model: str, - credentials: dict, - prompt_messages: list[PromptMessage], - model_parameters: dict, - tools: list[PromptMessageTool] | None = None, - stop: Sequence[str] | None = None, - stream: bool = True, - user: str | None = None, - callbacks: list[Callback] | None = None, - ): - """ - Trigger before invoke callbacks - - :param model: model name - :param credentials: model credentials - :param prompt_messages: prompt messages - :param model_parameters: model parameters - :param tools: tools for tool calling - :param stop: stop words - :param stream: is stream response - :param user: unique user id - :param callbacks: callbacks - """ - _run_callbacks( - 
callbacks, - event="on_before_invoke", - invoke=lambda callback: callback.on_before_invoke( - llm_instance=self, - model=model, - credentials=credentials, - prompt_messages=prompt_messages, - model_parameters=model_parameters, - tools=tools, - stop=stop, - stream=stream, - user=user, - ), - ) - - def _trigger_new_chunk_callbacks( - self, - chunk: LLMResultChunk, - model: str, - credentials: dict, - prompt_messages: Sequence[PromptMessage], - model_parameters: dict, - tools: list[PromptMessageTool] | None = None, - stop: Sequence[str] | None = None, - stream: bool = True, - user: str | None = None, - callbacks: list[Callback] | None = None, - ): - """ - Trigger new chunk callbacks - - :param chunk: chunk - :param model: model name - :param credentials: model credentials - :param prompt_messages: prompt messages - :param model_parameters: model parameters - :param tools: tools for tool calling - :param stop: stop words - :param stream: is stream response - :param user: unique user id - """ - _run_callbacks( - callbacks, - event="on_new_chunk", - invoke=lambda callback: callback.on_new_chunk( - llm_instance=self, - chunk=chunk, - model=model, - credentials=credentials, - prompt_messages=prompt_messages, - model_parameters=model_parameters, - tools=tools, - stop=stop, - stream=stream, - user=user, - ), - ) - - def _trigger_after_invoke_callbacks( - self, - model: str, - result: LLMResult, - credentials: dict, - prompt_messages: Sequence[PromptMessage], - model_parameters: dict, - tools: list[PromptMessageTool] | None = None, - stop: Sequence[str] | None = None, - stream: bool = True, - user: str | None = None, - callbacks: list[Callback] | None = None, - ): - """ - Trigger after invoke callbacks - - :param model: model name - :param result: result - :param credentials: model credentials - :param prompt_messages: prompt messages - :param model_parameters: model parameters - :param tools: tools for tool calling - :param stop: stop words - :param stream: is stream 
response - :param user: unique user id - :param callbacks: callbacks - """ - _run_callbacks( - callbacks, - event="on_after_invoke", - invoke=lambda callback: callback.on_after_invoke( - llm_instance=self, - result=result, - model=model, - credentials=credentials, - prompt_messages=prompt_messages, - model_parameters=model_parameters, - tools=tools, - stop=stop, - stream=stream, - user=user, - ), - ) - - def _trigger_invoke_error_callbacks( - self, - model: str, - ex: Exception, - credentials: dict, - prompt_messages: list[PromptMessage], - model_parameters: dict, - tools: list[PromptMessageTool] | None = None, - stop: Sequence[str] | None = None, - stream: bool = True, - user: str | None = None, - callbacks: list[Callback] | None = None, - ): - """ - Trigger invoke error callbacks - - :param model: model name - :param ex: exception - :param credentials: model credentials - :param prompt_messages: prompt messages - :param model_parameters: model parameters - :param tools: tools for tool calling - :param stop: stop words - :param stream: is stream response - :param user: unique user id - :param callbacks: callbacks - """ - _run_callbacks( - callbacks, - event="on_invoke_error", - invoke=lambda callback: callback.on_invoke_error( - llm_instance=self, - ex=ex, - model=model, - credentials=credentials, - prompt_messages=prompt_messages, - model_parameters=model_parameters, - tools=tools, - stop=stop, - stream=stream, - user=user, - ), - ) diff --git a/api/dify_graph/model_runtime/model_providers/__base/moderation_model.py b/api/dify_graph/model_runtime/model_providers/__base/moderation_model.py deleted file mode 100644 index 5fa3d1634b..0000000000 --- a/api/dify_graph/model_runtime/model_providers/__base/moderation_model.py +++ /dev/null @@ -1,45 +0,0 @@ -import time - -from pydantic import ConfigDict - -from dify_graph.model_runtime.entities.model_entities import ModelType -from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel - - -class 
ModerationModel(AIModel): - """ - Model class for moderation model. - """ - - model_type: ModelType = ModelType.MODERATION - - # pydantic configs - model_config = ConfigDict(protected_namespaces=()) - - def invoke(self, model: str, credentials: dict, text: str, user: str | None = None) -> bool: - """ - Invoke moderation model - - :param model: model name - :param credentials: model credentials - :param text: text to moderate - :param user: unique user id - :return: false if text is safe, true otherwise - """ - self.started_at = time.perf_counter() - - try: - from core.plugin.impl.model import PluginModelClient - - plugin_model_manager = PluginModelClient() - return plugin_model_manager.invoke_moderation( - tenant_id=self.tenant_id, - user_id=user or "unknown", - plugin_id=self.plugin_id, - provider=self.provider_name, - model=model, - credentials=credentials, - text=text, - ) - except Exception as e: - raise self._transform_invoke_error(e) diff --git a/api/dify_graph/model_runtime/model_providers/__base/rerank_model.py b/api/dify_graph/model_runtime/model_providers/__base/rerank_model.py deleted file mode 100644 index 5da2b84b95..0000000000 --- a/api/dify_graph/model_runtime/model_providers/__base/rerank_model.py +++ /dev/null @@ -1,92 +0,0 @@ -from dify_graph.model_runtime.entities.model_entities import ModelType -from dify_graph.model_runtime.entities.rerank_entities import RerankResult -from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel - - -class RerankModel(AIModel): - """ - Base Model class for rerank model. 
- """ - - model_type: ModelType = ModelType.RERANK - - def invoke( - self, - model: str, - credentials: dict, - query: str, - docs: list[str], - score_threshold: float | None = None, - top_n: int | None = None, - user: str | None = None, - ) -> RerankResult: - """ - Invoke rerank model - - :param model: model name - :param credentials: model credentials - :param query: search query - :param docs: docs for reranking - :param score_threshold: score threshold - :param top_n: top n - :param user: unique user id - :return: rerank result - """ - try: - from core.plugin.impl.model import PluginModelClient - - plugin_model_manager = PluginModelClient() - return plugin_model_manager.invoke_rerank( - tenant_id=self.tenant_id, - user_id=user or "unknown", - plugin_id=self.plugin_id, - provider=self.provider_name, - model=model, - credentials=credentials, - query=query, - docs=docs, - score_threshold=score_threshold, - top_n=top_n, - ) - except Exception as e: - raise self._transform_invoke_error(e) - - def invoke_multimodal_rerank( - self, - model: str, - credentials: dict, - query: dict, - docs: list[dict], - score_threshold: float | None = None, - top_n: int | None = None, - user: str | None = None, - ) -> RerankResult: - """ - Invoke multimodal rerank model - :param model: model name - :param credentials: model credentials - :param query: search query - :param docs: docs for reranking - :param score_threshold: score threshold - :param top_n: top n - :param user: unique user id - :return: rerank result - """ - try: - from core.plugin.impl.model import PluginModelClient - - plugin_model_manager = PluginModelClient() - return plugin_model_manager.invoke_multimodal_rerank( - tenant_id=self.tenant_id, - user_id=user or "unknown", - plugin_id=self.plugin_id, - provider=self.provider_name, - model=model, - credentials=credentials, - query=query, - docs=docs, - score_threshold=score_threshold, - top_n=top_n, - ) - except Exception as e: - raise self._transform_invoke_error(e) diff 
--git a/api/dify_graph/model_runtime/model_providers/__base/speech2text_model.py b/api/dify_graph/model_runtime/model_providers/__base/speech2text_model.py deleted file mode 100644 index e69069a85d..0000000000 --- a/api/dify_graph/model_runtime/model_providers/__base/speech2text_model.py +++ /dev/null @@ -1,43 +0,0 @@ -from typing import IO - -from pydantic import ConfigDict - -from dify_graph.model_runtime.entities.model_entities import ModelType -from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel - - -class Speech2TextModel(AIModel): - """ - Model class for speech2text model. - """ - - model_type: ModelType = ModelType.SPEECH2TEXT - - # pydantic configs - model_config = ConfigDict(protected_namespaces=()) - - def invoke(self, model: str, credentials: dict, file: IO[bytes], user: str | None = None) -> str: - """ - Invoke speech to text model - - :param model: model name - :param credentials: model credentials - :param file: audio file - :param user: unique user id - :return: text for given audio file - """ - try: - from core.plugin.impl.model import PluginModelClient - - plugin_model_manager = PluginModelClient() - return plugin_model_manager.invoke_speech_to_text( - tenant_id=self.tenant_id, - user_id=user or "unknown", - plugin_id=self.plugin_id, - provider=self.provider_name, - model=model, - credentials=credentials, - file=file, - ) - except Exception as e: - raise self._transform_invoke_error(e) diff --git a/api/dify_graph/model_runtime/model_providers/__base/text_embedding_model.py b/api/dify_graph/model_runtime/model_providers/__base/text_embedding_model.py deleted file mode 100644 index 3438da2ada..0000000000 --- a/api/dify_graph/model_runtime/model_providers/__base/text_embedding_model.py +++ /dev/null @@ -1,121 +0,0 @@ -from pydantic import ConfigDict - -from core.entities.embedding_type import EmbeddingInputType -from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey, ModelType -from 
dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingResult -from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel - - -class TextEmbeddingModel(AIModel): - """ - Model class for text embedding model. - """ - - model_type: ModelType = ModelType.TEXT_EMBEDDING - - # pydantic configs - model_config = ConfigDict(protected_namespaces=()) - - def invoke( - self, - model: str, - credentials: dict, - texts: list[str] | None = None, - multimodel_documents: list[dict] | None = None, - user: str | None = None, - input_type: EmbeddingInputType = EmbeddingInputType.DOCUMENT, - ) -> EmbeddingResult: - """ - Invoke text embedding model - - :param model: model name - :param credentials: model credentials - :param texts: texts to embed - :param files: files to embed - :param user: unique user id - :param input_type: input type - :return: embeddings result - """ - from core.plugin.impl.model import PluginModelClient - - try: - plugin_model_manager = PluginModelClient() - if texts: - return plugin_model_manager.invoke_text_embedding( - tenant_id=self.tenant_id, - user_id=user or "unknown", - plugin_id=self.plugin_id, - provider=self.provider_name, - model=model, - credentials=credentials, - texts=texts, - input_type=input_type, - ) - if multimodel_documents: - return plugin_model_manager.invoke_multimodal_embedding( - tenant_id=self.tenant_id, - user_id=user or "unknown", - plugin_id=self.plugin_id, - provider=self.provider_name, - model=model, - credentials=credentials, - documents=multimodel_documents, - input_type=input_type, - ) - raise ValueError("No texts or files provided") - except Exception as e: - raise self._transform_invoke_error(e) - - def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> list[int]: - """ - Get number of tokens for given prompt messages - - :param model: model name - :param credentials: model credentials - :param texts: texts to embed - :return: - """ - from core.plugin.impl.model import 
PluginModelClient - - plugin_model_manager = PluginModelClient() - return plugin_model_manager.get_text_embedding_num_tokens( - tenant_id=self.tenant_id, - user_id="unknown", - plugin_id=self.plugin_id, - provider=self.provider_name, - model=model, - credentials=credentials, - texts=texts, - ) - - def _get_context_size(self, model: str, credentials: dict) -> int: - """ - Get context size for given embedding model - - :param model: model name - :param credentials: model credentials - :return: context size - """ - model_schema = self.get_model_schema(model, credentials) - - if model_schema and ModelPropertyKey.CONTEXT_SIZE in model_schema.model_properties: - content_size: int = model_schema.model_properties[ModelPropertyKey.CONTEXT_SIZE] - return content_size - - return 1000 - - def _get_max_chunks(self, model: str, credentials: dict) -> int: - """ - Get max chunks for given embedding model - - :param model: model name - :param credentials: model credentials - :return: max chunks - """ - model_schema = self.get_model_schema(model, credentials) - - if model_schema and ModelPropertyKey.MAX_CHUNKS in model_schema.model_properties: - max_chunks: int = model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS] - return max_chunks - - return 1 diff --git a/api/dify_graph/model_runtime/model_providers/__base/tokenizers/gpt2_tokenizer.py b/api/dify_graph/model_runtime/model_providers/__base/tokenizers/gpt2_tokenizer.py deleted file mode 100644 index 3967acf07b..0000000000 --- a/api/dify_graph/model_runtime/model_providers/__base/tokenizers/gpt2_tokenizer.py +++ /dev/null @@ -1,53 +0,0 @@ -import logging -from threading import Lock -from typing import Any - -logger = logging.getLogger(__name__) - -_tokenizer: Any | None = None -_lock = Lock() - - -class GPT2Tokenizer: - @staticmethod - def _get_num_tokens_by_gpt2(text: str) -> int: - """ - use gpt2 tokenizer to get num tokens - """ - _tokenizer = GPT2Tokenizer.get_encoder() - tokens = _tokenizer.encode(text) # type: ignore - 
return len(tokens) - - @staticmethod - def get_num_tokens(text: str) -> int: - # Because this process needs more cpu resource, we turn this back before we find a better way to handle it. - # - # future = _executor.submit(GPT2Tokenizer._get_num_tokens_by_gpt2, text) - # result = future.result() - # return cast(int, result) - return GPT2Tokenizer._get_num_tokens_by_gpt2(text) - - @staticmethod - def get_encoder(): - global _tokenizer, _lock - if _tokenizer is not None: - return _tokenizer - with _lock: - if _tokenizer is None: - # Try to use tiktoken to get the tokenizer because it is faster - # - try: - import tiktoken - - _tokenizer = tiktoken.get_encoding("gpt2") - except Exception: - from os.path import abspath, dirname, join - - from transformers import GPT2Tokenizer as TransformerGPT2Tokenizer - - base_path = abspath(__file__) - gpt2_tokenizer_path = join(dirname(base_path), "gpt2") - _tokenizer = TransformerGPT2Tokenizer.from_pretrained(gpt2_tokenizer_path) - logger.info("Fallback to Transformers' GPT-2 tokenizer from tiktoken") - - return _tokenizer diff --git a/api/dify_graph/model_runtime/model_providers/__base/tts_model.py b/api/dify_graph/model_runtime/model_providers/__base/tts_model.py deleted file mode 100644 index 0656529f22..0000000000 --- a/api/dify_graph/model_runtime/model_providers/__base/tts_model.py +++ /dev/null @@ -1,79 +0,0 @@ -import logging -from collections.abc import Iterable - -from pydantic import ConfigDict - -from dify_graph.model_runtime.entities.model_entities import ModelType -from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel - -logger = logging.getLogger(__name__) - - -class TTSModel(AIModel): - """ - Model class for TTS model. 
- """ - - model_type: ModelType = ModelType.TTS - - # pydantic configs - model_config = ConfigDict(protected_namespaces=()) - - def invoke( - self, - model: str, - tenant_id: str, - credentials: dict, - content_text: str, - voice: str, - user: str | None = None, - ) -> Iterable[bytes]: - """ - Invoke large language model - - :param model: model name - :param tenant_id: user tenant id - :param credentials: model credentials - :param voice: model timbre - :param content_text: text content to be translated - :param user: unique user id - :return: translated audio file - """ - try: - from core.plugin.impl.model import PluginModelClient - - plugin_model_manager = PluginModelClient() - return plugin_model_manager.invoke_tts( - tenant_id=self.tenant_id, - user_id=user or "unknown", - plugin_id=self.plugin_id, - provider=self.provider_name, - model=model, - credentials=credentials, - content_text=content_text, - voice=voice, - ) - except Exception as e: - raise self._transform_invoke_error(e) - - def get_tts_model_voices(self, model: str, credentials: dict, language: str | None = None): - """ - Retrieves the list of voices supported by a given text-to-speech (TTS) model. - - :param language: The language for which the voices are requested. - :param model: The name of the TTS model. - :param credentials: The credentials required to access the TTS model. - :return: A list of voices supported by the TTS model. 
- """ - from core.plugin.impl.model import PluginModelClient - - plugin_model_manager = PluginModelClient() - return plugin_model_manager.get_tts_model_voices( - tenant_id=self.tenant_id, - user_id="unknown", - plugin_id=self.plugin_id, - provider=self.provider_name, - model=model, - credentials=credentials, - language=language, - ) diff --git a/api/dify_graph/model_runtime/model_providers/_position.yaml b/api/dify_graph/model_runtime/model_providers/_position.yaml deleted file mode 100644 index fb02de3a67..0000000000 --- a/api/dify_graph/model_runtime/model_providers/_position.yaml +++ /dev/null @@ -1,43 +0,0 @@ -- openai -- deepseek -- anthropic -- azure_openai -- google -- vertex_ai -- nvidia -- nvidia_nim -- cohere -- upstage -- bedrock -- togetherai -- openrouter -- ollama -- mistralai -- groq -- replicate -- huggingface_hub -- xinference -- triton_inference_server -- zhipuai -- baichuan -- spark -- minimax -- tongyi -- wenxin -- moonshot -- tencent -- jina -- chatglm -- yi -- openllm -- localai -- volcengine_maas -- openai_api_compatible -- hunyuan -- siliconflow -- perfxcloud -- zhinao -- fireworks -- mixedbread -- nomic -- voyage diff --git a/api/dify_graph/model_runtime/model_providers/model_provider_factory.py b/api/dify_graph/model_runtime/model_providers/model_provider_factory.py deleted file mode 100644 index de0677a348..0000000000 --- a/api/dify_graph/model_runtime/model_providers/model_provider_factory.py +++ /dev/null @@ -1,387 +0,0 @@ -from __future__ import annotations - -import hashlib -import logging -from collections.abc import Sequence -from threading import Lock - -from pydantic import ValidationError -from redis import RedisError - -import contexts -from configs import dify_config -from core.plugin.entities.plugin_daemon import PluginModelProviderEntity -from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelType -from dify_graph.model_runtime.entities.provider_entities import ProviderConfig, ProviderEntity, 
SimpleProviderEntity -from dify_graph.model_runtime.model_providers.__base.ai_model import AIModel -from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from dify_graph.model_runtime.model_providers.__base.moderation_model import ModerationModel -from dify_graph.model_runtime.model_providers.__base.rerank_model import RerankModel -from dify_graph.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel -from dify_graph.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel -from dify_graph.model_runtime.model_providers.__base.tts_model import TTSModel -from dify_graph.model_runtime.schema_validators.model_credential_schema_validator import ModelCredentialSchemaValidator -from dify_graph.model_runtime.schema_validators.provider_credential_schema_validator import ( - ProviderCredentialSchemaValidator, -) -from extensions.ext_redis import redis_client -from models.provider_ids import ModelProviderID - -logger = logging.getLogger(__name__) - - -class ModelProviderFactory: - def __init__(self, tenant_id: str): - from core.plugin.impl.model import PluginModelClient - - self.tenant_id = tenant_id - self.plugin_model_manager = PluginModelClient() - - def get_providers(self) -> Sequence[ProviderEntity]: - """ - Get all providers - :return: list of providers - """ - # FIXME(-LAN-): Removed position map sorting since providers are fetched from plugin server - # The plugin server should return providers in the desired order - plugin_providers = self.get_plugin_model_providers() - return [provider.declaration for provider in plugin_providers] - - def get_plugin_model_providers(self) -> Sequence[PluginModelProviderEntity]: - """ - Get all plugin model providers - :return: list of plugin model providers - """ - # check if context is set - try: - contexts.plugin_model_providers.get() - except LookupError: - contexts.plugin_model_providers.set(None) - 
contexts.plugin_model_providers_lock.set(Lock()) - - with contexts.plugin_model_providers_lock.get(): - plugin_model_providers = contexts.plugin_model_providers.get() - if plugin_model_providers is not None: - return plugin_model_providers - - plugin_model_providers = [] - contexts.plugin_model_providers.set(plugin_model_providers) - - # Fetch plugin model providers - plugin_providers = self.plugin_model_manager.fetch_model_providers(self.tenant_id) - - for provider in plugin_providers: - provider.declaration.provider = provider.plugin_id + "/" + provider.declaration.provider - plugin_model_providers.append(provider) - - return plugin_model_providers - - def get_provider_schema(self, provider: str) -> ProviderEntity: - """ - Get provider schema - :param provider: provider name - :return: provider schema - """ - plugin_model_provider_entity = self.get_plugin_model_provider(provider=provider) - return plugin_model_provider_entity.declaration - - def get_plugin_model_provider(self, provider: str) -> PluginModelProviderEntity: - """ - Get plugin model provider - :param provider: provider name - :return: provider schema - """ - if "/" not in provider: - provider = str(ModelProviderID(provider)) - - # fetch plugin model providers - plugin_model_provider_entities = self.get_plugin_model_providers() - - # get the provider - plugin_model_provider_entity = next( - (p for p in plugin_model_provider_entities if p.declaration.provider == provider), - None, - ) - - if not plugin_model_provider_entity: - raise ValueError(f"Invalid provider: {provider}") - - return plugin_model_provider_entity - - def provider_credentials_validate(self, *, provider: str, credentials: dict): - """ - Validate provider credentials - - :param provider: provider name - :param credentials: provider credentials, credentials form defined in `provider_credential_schema`. 
- :return: - """ - # fetch plugin model provider - plugin_model_provider_entity = self.get_plugin_model_provider(provider=provider) - - # get provider_credential_schema and validate credentials according to the rules - provider_credential_schema = plugin_model_provider_entity.declaration.provider_credential_schema - if not provider_credential_schema: - raise ValueError(f"Provider {provider} does not have provider_credential_schema") - - # validate provider credential schema - validator = ProviderCredentialSchemaValidator(provider_credential_schema) - filtered_credentials = validator.validate_and_filter(credentials) - - # validate the credentials, raise exception if validation failed - self.plugin_model_manager.validate_provider_credentials( - tenant_id=self.tenant_id, - user_id="unknown", - plugin_id=plugin_model_provider_entity.plugin_id, - provider=plugin_model_provider_entity.provider, - credentials=filtered_credentials, - ) - - return filtered_credentials - - def model_credentials_validate(self, *, provider: str, model_type: ModelType, model: str, credentials: dict): - """ - Validate model credentials - - :param provider: provider name - :param model_type: model type - :param model: model name - :param credentials: model credentials, credentials form defined in `model_credential_schema`. 
- :return: - """ - # fetch plugin model provider - plugin_model_provider_entity = self.get_plugin_model_provider(provider=provider) - - # get model_credential_schema and validate credentials according to the rules - model_credential_schema = plugin_model_provider_entity.declaration.model_credential_schema - if not model_credential_schema: - raise ValueError(f"Provider {provider} does not have model_credential_schema") - - # validate model credential schema - validator = ModelCredentialSchemaValidator(model_type, model_credential_schema) - filtered_credentials = validator.validate_and_filter(credentials) - - # call validate_credentials method of model type to validate credentials, raise exception if validation failed - self.plugin_model_manager.validate_model_credentials( - tenant_id=self.tenant_id, - user_id="unknown", - plugin_id=plugin_model_provider_entity.plugin_id, - provider=plugin_model_provider_entity.provider, - model_type=model_type.value, - model=model, - credentials=filtered_credentials, - ) - - return filtered_credentials - - def get_model_schema( - self, *, provider: str, model_type: ModelType, model: str, credentials: dict | None - ) -> AIModelEntity | None: - """ - Get model schema - """ - plugin_id, provider_name = self.get_plugin_id_and_provider_name_from_provider(provider) - cache_key = f"{self.tenant_id}:{plugin_id}:{provider_name}:{model_type.value}:{model}" - sorted_credentials = sorted(credentials.items()) if credentials else [] - cache_key += ":".join([hashlib.md5(f"{k}:{v}".encode()).hexdigest() for k, v in sorted_credentials]) - - cached_schema_json = None - try: - cached_schema_json = redis_client.get(cache_key) - except (RedisError, RuntimeError) as exc: - logger.warning( - "Failed to read plugin model schema cache for model %s: %s", - model, - str(exc), - exc_info=True, - ) - if cached_schema_json: - try: - return AIModelEntity.model_validate_json(cached_schema_json) - except ValidationError: - logger.warning( - "Failed to validate 
cached plugin model schema for model %s", - model, - exc_info=True, - ) - try: - redis_client.delete(cache_key) - except (RedisError, RuntimeError) as exc: - logger.warning( - "Failed to delete invalid plugin model schema cache for model %s: %s", - model, - str(exc), - exc_info=True, - ) - - schema = self.plugin_model_manager.get_model_schema( - tenant_id=self.tenant_id, - user_id="unknown", - plugin_id=plugin_id, - provider=provider_name, - model_type=model_type.value, - model=model, - credentials=credentials or {}, - ) - - if schema: - try: - redis_client.setex(cache_key, dify_config.PLUGIN_MODEL_SCHEMA_CACHE_TTL, schema.model_dump_json()) - except (RedisError, RuntimeError) as exc: - logger.warning( - "Failed to write plugin model schema cache for model %s: %s", - model, - str(exc), - exc_info=True, - ) - - return schema - - def get_models( - self, - *, - provider: str | None = None, - model_type: ModelType | None = None, - provider_configs: list[ProviderConfig] | None = None, - ) -> list[SimpleProviderEntity]: - """ - Get all models for given model type - - :param provider: provider name - :param model_type: model type - :param provider_configs: list of provider configs - :return: list of models - """ - provider_configs = provider_configs or [] - - # scan all providers - plugin_model_provider_entities = self.get_plugin_model_providers() - - # traverse all model_provider_extensions - providers = [] - for plugin_model_provider_entity in plugin_model_provider_entities: - # filter by provider if provider is present - if provider and plugin_model_provider_entity.declaration.provider != provider: - continue - - # get provider schema - provider_schema = plugin_model_provider_entity.declaration - - model_types = provider_schema.supported_model_types - if model_type: - if model_type not in model_types: - continue - - model_types = [model_type] - - all_model_type_models = [] - for model_schema in provider_schema.models: - if model_schema.model_type != model_type: - 
continue - - all_model_type_models.append(model_schema) - - simple_provider_schema = provider_schema.to_simple_provider() - if model_type: - simple_provider_schema.models = all_model_type_models - - providers.append(simple_provider_schema) - - return providers - - def get_model_type_instance(self, provider: str, model_type: ModelType) -> AIModel: - """ - Get model type instance by provider name and model type - :param provider: provider name - :param model_type: model type - :return: model type instance - """ - plugin_id, provider_name = self.get_plugin_id_and_provider_name_from_provider(provider) - init_params = { - "tenant_id": self.tenant_id, - "plugin_id": plugin_id, - "provider_name": provider_name, - "plugin_model_provider": self.get_plugin_model_provider(provider), - } - - if model_type == ModelType.LLM: - return LargeLanguageModel.model_validate(init_params) - elif model_type == ModelType.TEXT_EMBEDDING: - return TextEmbeddingModel.model_validate(init_params) - elif model_type == ModelType.RERANK: - return RerankModel.model_validate(init_params) - elif model_type == ModelType.SPEECH2TEXT: - return Speech2TextModel.model_validate(init_params) - elif model_type == ModelType.MODERATION: - return ModerationModel.model_validate(init_params) - elif model_type == ModelType.TTS: - return TTSModel.model_validate(init_params) - - raise ValueError(f"Unsupported model type: {model_type}") - - def get_provider_icon(self, provider: str, icon_type: str, lang: str) -> tuple[bytes, str]: - """ - Get provider icon - :param provider: provider name - :param icon_type: icon type (icon_small or icon_small_dark) - :param lang: language (zh_Hans or en_US) - :return: provider icon - """ - # get the provider schema - provider_schema = self.get_provider_schema(provider) - - if icon_type.lower() == "icon_small": - if not provider_schema.icon_small: - raise ValueError(f"Provider {provider} does not have small icon.") - - if lang.lower() == "zh_hans": - file_name = 
provider_schema.icon_small.zh_Hans - else: - file_name = provider_schema.icon_small.en_US - elif icon_type.lower() == "icon_small_dark": - if not provider_schema.icon_small_dark: - raise ValueError(f"Provider {provider} does not have small dark icon.") - - if lang.lower() == "zh_hans": - file_name = provider_schema.icon_small_dark.zh_Hans - else: - file_name = provider_schema.icon_small_dark.en_US - else: - raise ValueError(f"Unsupported icon type: {icon_type}.") - - if not file_name: - raise ValueError(f"Provider {provider} does not have icon.") - - image_mime_types = { - "jpg": "image/jpeg", - "jpeg": "image/jpeg", - "png": "image/png", - "gif": "image/gif", - "bmp": "image/bmp", - "tiff": "image/tiff", - "tif": "image/tiff", - "webp": "image/webp", - "svg": "image/svg+xml", - "ico": "image/vnd.microsoft.icon", - "heif": "image/heif", - "heic": "image/heic", - } - - extension = file_name.split(".")[-1] - mime_type = image_mime_types.get(extension, "image/png") - - # get icon bytes from plugin asset manager - from core.plugin.impl.asset import PluginAssetManager - - plugin_asset_manager = PluginAssetManager() - return plugin_asset_manager.fetch_asset(tenant_id=self.tenant_id, id=file_name), mime_type - - def get_plugin_id_and_provider_name_from_provider(self, provider: str) -> tuple[str, str]: - """ - Get plugin id and provider name from provider name - :param provider: provider name - :return: plugin id and provider name - """ - - provider_id = ModelProviderID(provider) - return provider_id.plugin_id, provider_id.provider_name diff --git a/api/dify_graph/model_runtime/schema_validators/__init__.py b/api/dify_graph/model_runtime/schema_validators/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/api/dify_graph/model_runtime/schema_validators/common_validator.py b/api/dify_graph/model_runtime/schema_validators/common_validator.py deleted file mode 100644 index 04cdb8e4f7..0000000000 --- 
a/api/dify_graph/model_runtime/schema_validators/common_validator.py +++ /dev/null @@ -1,92 +0,0 @@ -from typing import Union, cast - -from dify_graph.model_runtime.entities.provider_entities import CredentialFormSchema, FormType - - -class CommonValidator: - def _validate_and_filter_credential_form_schemas( - self, credential_form_schemas: list[CredentialFormSchema], credentials: dict - ): - need_validate_credential_form_schema_map = {} - for credential_form_schema in credential_form_schemas: - if not credential_form_schema.show_on: - need_validate_credential_form_schema_map[credential_form_schema.variable] = credential_form_schema - continue - - all_show_on_match = True - for show_on_object in credential_form_schema.show_on: - if show_on_object.variable not in credentials: - all_show_on_match = False - break - - if credentials[show_on_object.variable] != show_on_object.value: - all_show_on_match = False - break - - if all_show_on_match: - need_validate_credential_form_schema_map[credential_form_schema.variable] = credential_form_schema - - # Iterate over the remaining credential_form_schemas, verify each credential_form_schema - validated_credentials = {} - for credential_form_schema in need_validate_credential_form_schema_map.values(): - # add the value of the credential_form_schema corresponding to it to validated_credentials - result = self._validate_credential_form_schema(credential_form_schema, credentials) - if result: - validated_credentials[credential_form_schema.variable] = result - - return validated_credentials - - def _validate_credential_form_schema( - self, credential_form_schema: CredentialFormSchema, credentials: dict - ) -> Union[str, bool, None]: - """ - Validate credential form schema - - :param credential_form_schema: credential form schema - :param credentials: credentials - :return: validated credential form schema value - """ - # If the variable does not exist in credentials - value: Union[str, bool, None] = None - if 
credential_form_schema.variable not in credentials or not credentials[credential_form_schema.variable]: - # If required is True, an exception is thrown - if credential_form_schema.required: - raise ValueError(f"Variable {credential_form_schema.variable} is required") - else: - # Get the value of default - if credential_form_schema.default: - # If it exists, add it to validated_credentials - return credential_form_schema.default - else: - # If default does not exist, skip - return None - - # Get the value corresponding to the variable from credentials - value = cast(str, credentials[credential_form_schema.variable]) - - # If max_length=0, no validation is performed - if credential_form_schema.max_length: - if len(value) > credential_form_schema.max_length: - raise ValueError( - f"Variable {credential_form_schema.variable} length should not be" - f" greater than {credential_form_schema.max_length}" - ) - - # check the type of value - if not isinstance(value, str): - raise ValueError(f"Variable {credential_form_schema.variable} should be string") - - if credential_form_schema.type in {FormType.SELECT, FormType.RADIO}: - # If the value is in options, no validation is performed - if credential_form_schema.options: - if value not in [option.value for option in credential_form_schema.options]: - raise ValueError(f"Variable {credential_form_schema.variable} is not in options") - - if credential_form_schema.type == FormType.SWITCH: - # If the value is not in ['true', 'false'], an exception is thrown - if value.lower() not in {"true", "false"}: - raise ValueError(f"Variable {credential_form_schema.variable} should be true or false") - - value = value.lower() == "true" - - return value diff --git a/api/dify_graph/model_runtime/schema_validators/model_credential_schema_validator.py b/api/dify_graph/model_runtime/schema_validators/model_credential_schema_validator.py deleted file mode 100644 index a97796e98f..0000000000 --- 
a/api/dify_graph/model_runtime/schema_validators/model_credential_schema_validator.py +++ /dev/null @@ -1,27 +0,0 @@ -from dify_graph.model_runtime.entities.model_entities import ModelType -from dify_graph.model_runtime.entities.provider_entities import ModelCredentialSchema -from dify_graph.model_runtime.schema_validators.common_validator import CommonValidator - - -class ModelCredentialSchemaValidator(CommonValidator): - def __init__(self, model_type: ModelType, model_credential_schema: ModelCredentialSchema): - self.model_type = model_type - self.model_credential_schema = model_credential_schema - - def validate_and_filter(self, credentials: dict): - """ - Validate model credentials - - :param credentials: model credentials - :return: filtered credentials - """ - - if self.model_credential_schema is None: - raise ValueError("Model credential schema is None") - - # get the credential_form_schemas in provider_credential_schema - credential_form_schemas = self.model_credential_schema.credential_form_schemas - - credentials["__model_type"] = self.model_type.value - - return self._validate_and_filter_credential_form_schemas(credential_form_schemas, credentials) diff --git a/api/dify_graph/model_runtime/schema_validators/provider_credential_schema_validator.py b/api/dify_graph/model_runtime/schema_validators/provider_credential_schema_validator.py deleted file mode 100644 index 2fed75a76c..0000000000 --- a/api/dify_graph/model_runtime/schema_validators/provider_credential_schema_validator.py +++ /dev/null @@ -1,19 +0,0 @@ -from dify_graph.model_runtime.entities.provider_entities import ProviderCredentialSchema -from dify_graph.model_runtime.schema_validators.common_validator import CommonValidator - - -class ProviderCredentialSchemaValidator(CommonValidator): - def __init__(self, provider_credential_schema: ProviderCredentialSchema): - self.provider_credential_schema = provider_credential_schema - - def validate_and_filter(self, credentials: dict): - """ - Validate 
provider credentials - - :param credentials: provider credentials - :return: validated provider credentials - """ - # get the credential_form_schemas in provider_credential_schema - credential_form_schemas = self.provider_credential_schema.credential_form_schemas - - return self._validate_and_filter_credential_form_schemas(credential_form_schemas, credentials) diff --git a/api/dify_graph/model_runtime/utils/__init__.py b/api/dify_graph/model_runtime/utils/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/api/dify_graph/model_runtime/utils/encoders.py b/api/dify_graph/model_runtime/utils/encoders.py deleted file mode 100644 index c85152463e..0000000000 --- a/api/dify_graph/model_runtime/utils/encoders.py +++ /dev/null @@ -1,216 +0,0 @@ -import dataclasses -import datetime -from collections import defaultdict, deque -from collections.abc import Callable -from decimal import Decimal -from enum import Enum -from ipaddress import IPv4Address, IPv4Interface, IPv4Network, IPv6Address, IPv6Interface, IPv6Network -from pathlib import Path, PurePath -from re import Pattern -from types import GeneratorType -from typing import Any, Literal, Union -from uuid import UUID - -from pydantic import BaseModel -from pydantic.networks import AnyUrl, NameEmail -from pydantic.types import SecretBytes, SecretStr -from pydantic_core import Url -from pydantic_extra_types.color import Color - - -def _model_dump(model: BaseModel, mode: Literal["json", "python"] = "json", **kwargs: Any) -> Any: - return model.model_dump(mode=mode, **kwargs) - - -# Taken from Pydantic v1 as is -def isoformat(o: Union[datetime.date, datetime.time]) -> str: - return o.isoformat() - - -# Taken from Pydantic v1 as is -# TODO: pv2 should this return strings instead? 
-def decimal_encoder(dec_value: Decimal) -> Union[int, float]: - """ - Encodes a Decimal as int of there's no exponent, otherwise float - - This is useful when we use ConstrainedDecimal to represent Numeric(x,0) - where a integer (but not int typed) is used. Encoding this as a float - results in failed round-tripping between encode and parse. - Our Id type is a prime example of this. - - >>> decimal_encoder(Decimal("1.0")) - 1.0 - - >>> decimal_encoder(Decimal("1")) - 1 - """ - if dec_value.as_tuple().exponent >= 0: # type: ignore[operator] - return int(dec_value) - else: - return float(dec_value) - - -ENCODERS_BY_TYPE: dict[type[Any], Callable[[Any], Any]] = { - bytes: lambda o: o.decode(), - Color: str, - datetime.date: isoformat, - datetime.datetime: isoformat, - datetime.time: isoformat, - datetime.timedelta: lambda td: td.total_seconds(), - Decimal: decimal_encoder, - Enum: lambda o: o.value, - frozenset: list, - deque: list, - GeneratorType: list, - IPv4Address: str, - IPv4Interface: str, - IPv4Network: str, - IPv6Address: str, - IPv6Interface: str, - IPv6Network: str, - NameEmail: str, - Path: str, - Pattern: lambda o: o.pattern, - SecretBytes: str, - SecretStr: str, - set: list, - UUID: str, - Url: str, - AnyUrl: str, -} - - -def generate_encoders_by_class_tuples( - type_encoder_map: dict[Any, Callable[[Any], Any]], -) -> dict[Callable[[Any], Any], tuple[Any, ...]]: - encoders_by_class_tuples: dict[Callable[[Any], Any], tuple[Any, ...]] = defaultdict(tuple) - for type_, encoder in type_encoder_map.items(): - encoders_by_class_tuples[encoder] += (type_,) - return encoders_by_class_tuples - - -encoders_by_class_tuples = generate_encoders_by_class_tuples(ENCODERS_BY_TYPE) - - -def jsonable_encoder( - obj: Any, - by_alias: bool = True, - exclude_unset: bool = False, - exclude_defaults: bool = False, - exclude_none: bool = False, - custom_encoder: dict[Any, Callable[[Any], Any]] | None = None, - sqlalchemy_safe: bool = True, -) -> Any: - custom_encoder = 
custom_encoder or {} - if custom_encoder: - if type(obj) in custom_encoder: - return custom_encoder[type(obj)](obj) - else: - for encoder_type, encoder_instance in custom_encoder.items(): - if isinstance(obj, encoder_type): - return encoder_instance(obj) - if isinstance(obj, BaseModel): - obj_dict = _model_dump( - obj, - mode="json", - include=None, - exclude=None, - by_alias=by_alias, - exclude_unset=exclude_unset, - exclude_none=exclude_none, - exclude_defaults=exclude_defaults, - ) - if "__root__" in obj_dict: - obj_dict = obj_dict["__root__"] - return jsonable_encoder( - obj_dict, - exclude_none=exclude_none, - exclude_defaults=exclude_defaults, - sqlalchemy_safe=sqlalchemy_safe, - ) - if dataclasses.is_dataclass(obj): - # Ensure obj is a dataclass instance, not a dataclass type - if not isinstance(obj, type): - obj_dict = dataclasses.asdict(obj) - return jsonable_encoder( - obj_dict, - by_alias=by_alias, - exclude_unset=exclude_unset, - exclude_defaults=exclude_defaults, - exclude_none=exclude_none, - custom_encoder=custom_encoder, - sqlalchemy_safe=sqlalchemy_safe, - ) - if isinstance(obj, Enum): - return obj.value - if isinstance(obj, PurePath): - return str(obj) - if isinstance(obj, str | int | float | type(None)): - return obj - if isinstance(obj, Decimal): - return format(obj, "f") - if isinstance(obj, dict): - encoded_dict = {} - for key, value in obj.items(): - if (not sqlalchemy_safe or (not isinstance(key, str)) or (not key.startswith("_sa"))) and ( - value is not None or not exclude_none - ): - encoded_key = jsonable_encoder( - key, - by_alias=by_alias, - exclude_unset=exclude_unset, - exclude_none=exclude_none, - custom_encoder=custom_encoder, - sqlalchemy_safe=sqlalchemy_safe, - ) - encoded_value = jsonable_encoder( - value, - by_alias=by_alias, - exclude_unset=exclude_unset, - exclude_none=exclude_none, - custom_encoder=custom_encoder, - sqlalchemy_safe=sqlalchemy_safe, - ) - encoded_dict[encoded_key] = encoded_value - return encoded_dict - if 
isinstance(obj, list | set | frozenset | GeneratorType | tuple | deque): - encoded_list = [] - for item in obj: - encoded_list.append( - jsonable_encoder( - item, - by_alias=by_alias, - exclude_unset=exclude_unset, - exclude_defaults=exclude_defaults, - exclude_none=exclude_none, - custom_encoder=custom_encoder, - sqlalchemy_safe=sqlalchemy_safe, - ) - ) - return encoded_list - - if type(obj) in ENCODERS_BY_TYPE: - return ENCODERS_BY_TYPE[type(obj)](obj) - for encoder, classes_tuple in encoders_by_class_tuples.items(): - if isinstance(obj, classes_tuple): - return encoder(obj) - - try: - data = dict(obj) # type: ignore - except Exception as e: - errors: list[Exception] = [] - errors.append(e) - try: - data = vars(obj) # type: ignore - except Exception as e: - errors.append(e) - raise ValueError(str(errors)) from e - return jsonable_encoder( - data, - by_alias=by_alias, - exclude_unset=exclude_unset, - exclude_defaults=exclude_defaults, - exclude_none=exclude_none, - custom_encoder=custom_encoder, - sqlalchemy_safe=sqlalchemy_safe, - ) diff --git a/api/dify_graph/node_events/__init__.py b/api/dify_graph/node_events/__init__.py deleted file mode 100644 index a9bef8f9a2..0000000000 --- a/api/dify_graph/node_events/__init__.py +++ /dev/null @@ -1,46 +0,0 @@ -from .agent import AgentLogEvent -from .base import NodeEventBase, NodeRunResult -from .iteration import ( - IterationFailedEvent, - IterationNextEvent, - IterationStartedEvent, - IterationSucceededEvent, -) -from .loop import ( - LoopFailedEvent, - LoopNextEvent, - LoopStartedEvent, - LoopSucceededEvent, -) -from .node import ( - HumanInputFormFilledEvent, - HumanInputFormTimeoutEvent, - ModelInvokeCompletedEvent, - PauseRequestedEvent, - RunRetrieverResourceEvent, - RunRetryEvent, - StreamChunkEvent, - StreamCompletedEvent, -) - -__all__ = [ - "AgentLogEvent", - "HumanInputFormFilledEvent", - "HumanInputFormTimeoutEvent", - "IterationFailedEvent", - "IterationNextEvent", - "IterationStartedEvent", - 
"IterationSucceededEvent", - "LoopFailedEvent", - "LoopNextEvent", - "LoopStartedEvent", - "LoopSucceededEvent", - "ModelInvokeCompletedEvent", - "NodeEventBase", - "NodeRunResult", - "PauseRequestedEvent", - "RunRetrieverResourceEvent", - "RunRetryEvent", - "StreamChunkEvent", - "StreamCompletedEvent", -] diff --git a/api/dify_graph/node_events/agent.py b/api/dify_graph/node_events/agent.py deleted file mode 100644 index bf295ec774..0000000000 --- a/api/dify_graph/node_events/agent.py +++ /dev/null @@ -1,18 +0,0 @@ -from collections.abc import Mapping -from typing import Any - -from pydantic import Field - -from .base import NodeEventBase - - -class AgentLogEvent(NodeEventBase): - message_id: str = Field(..., description="id") - label: str = Field(..., description="label") - node_execution_id: str = Field(..., description="node execution id") - parent_id: str | None = Field(..., description="parent id") - error: str | None = Field(..., description="error") - status: str = Field(..., description="status") - data: Mapping[str, Any] = Field(..., description="data") - metadata: Mapping[str, Any] = Field(default_factory=dict, description="metadata") - node_id: str = Field(..., description="node id") diff --git a/api/dify_graph/node_events/base.py b/api/dify_graph/node_events/base.py deleted file mode 100644 index 2f6259ae7d..0000000000 --- a/api/dify_graph/node_events/base.py +++ /dev/null @@ -1,40 +0,0 @@ -from collections.abc import Mapping -from typing import Any - -from pydantic import BaseModel, Field - -from dify_graph.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus -from dify_graph.model_runtime.entities.llm_entities import LLMUsage - - -class NodeEventBase(BaseModel): - """Base class for all node events""" - - pass - - -def _default_metadata(): - v: Mapping[WorkflowNodeExecutionMetadataKey, Any] = {} - return v - - -class NodeRunResult(BaseModel): - """ - Node Run Result. 
- """ - - status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.PENDING - - inputs: Mapping[str, Any] = Field(default_factory=dict) - process_data: Mapping[str, Any] = Field(default_factory=dict) - outputs: Mapping[str, Any] = Field(default_factory=dict) - metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] = Field(default_factory=_default_metadata) - llm_usage: LLMUsage = Field(default_factory=LLMUsage.empty_usage) - - edge_source_handle: str = "source" # source handle id of node with multiple branches - - error: str = "" - error_type: str = "" - - # single step node run retry - retry_index: int = 0 diff --git a/api/dify_graph/node_events/iteration.py b/api/dify_graph/node_events/iteration.py deleted file mode 100644 index 744ddea628..0000000000 --- a/api/dify_graph/node_events/iteration.py +++ /dev/null @@ -1,36 +0,0 @@ -from collections.abc import Mapping -from datetime import datetime -from typing import Any - -from pydantic import Field - -from .base import NodeEventBase - - -class IterationStartedEvent(NodeEventBase): - start_at: datetime = Field(..., description="start at") - inputs: Mapping[str, object] = Field(default_factory=dict) - metadata: Mapping[str, object] = Field(default_factory=dict) - predecessor_node_id: str | None = None - - -class IterationNextEvent(NodeEventBase): - index: int = Field(..., description="index") - pre_iteration_output: Any = None - - -class IterationSucceededEvent(NodeEventBase): - start_at: datetime = Field(..., description="start at") - inputs: Mapping[str, object] = Field(default_factory=dict) - outputs: Mapping[str, object] = Field(default_factory=dict) - metadata: Mapping[str, object] = Field(default_factory=dict) - steps: int = 0 - - -class IterationFailedEvent(NodeEventBase): - start_at: datetime = Field(..., description="start at") - inputs: Mapping[str, object] = Field(default_factory=dict) - outputs: Mapping[str, object] = Field(default_factory=dict) - metadata: Mapping[str, object] = 
Field(default_factory=dict) - steps: int = 0 - error: str = Field(..., description="failed reason") diff --git a/api/dify_graph/node_events/loop.py b/api/dify_graph/node_events/loop.py deleted file mode 100644 index 3ae230f9f6..0000000000 --- a/api/dify_graph/node_events/loop.py +++ /dev/null @@ -1,36 +0,0 @@ -from collections.abc import Mapping -from datetime import datetime -from typing import Any - -from pydantic import Field - -from .base import NodeEventBase - - -class LoopStartedEvent(NodeEventBase): - start_at: datetime = Field(..., description="start at") - inputs: Mapping[str, object] = Field(default_factory=dict) - metadata: Mapping[str, object] = Field(default_factory=dict) - predecessor_node_id: str | None = None - - -class LoopNextEvent(NodeEventBase): - index: int = Field(..., description="index") - pre_loop_output: Any = None - - -class LoopSucceededEvent(NodeEventBase): - start_at: datetime = Field(..., description="start at") - inputs: Mapping[str, object] = Field(default_factory=dict) - outputs: Mapping[str, object] = Field(default_factory=dict) - metadata: Mapping[str, object] = Field(default_factory=dict) - steps: int = 0 - - -class LoopFailedEvent(NodeEventBase): - start_at: datetime = Field(..., description="start at") - inputs: Mapping[str, object] = Field(default_factory=dict) - outputs: Mapping[str, object] = Field(default_factory=dict) - metadata: Mapping[str, object] = Field(default_factory=dict) - steps: int = 0 - error: str = Field(..., description="failed reason") diff --git a/api/dify_graph/node_events/node.py b/api/dify_graph/node_events/node.py deleted file mode 100644 index 2e3973b8fa..0000000000 --- a/api/dify_graph/node_events/node.py +++ /dev/null @@ -1,65 +0,0 @@ -from collections.abc import Mapping, Sequence -from datetime import datetime -from typing import Any - -from pydantic import Field - -from dify_graph.entities.pause_reason import PauseReason -from dify_graph.file import File -from 
dify_graph.model_runtime.entities.llm_entities import LLMUsage -from dify_graph.node_events import NodeRunResult - -from .base import NodeEventBase - - -class RunRetrieverResourceEvent(NodeEventBase): - retriever_resources: Sequence[Mapping[str, Any]] = Field(..., description="retriever resources") - context: str = Field(..., description="context") - context_files: list[File] | None = Field(default=None, description="context files") - - -class ModelInvokeCompletedEvent(NodeEventBase): - text: str - usage: LLMUsage - finish_reason: str | None = None - reasoning_content: str | None = None - structured_output: dict | None = None - - -class RunRetryEvent(NodeEventBase): - error: str = Field(..., description="error") - retry_index: int = Field(..., description="Retry attempt number") - start_at: datetime = Field(..., description="Retry start time") - - -class StreamChunkEvent(NodeEventBase): - # Spec-compliant fields - selector: Sequence[str] = Field( - ..., description="selector identifying the output location (e.g., ['nodeA', 'text'])" - ) - chunk: str = Field(..., description="the actual chunk content") - is_final: bool = Field(default=False, description="indicates if this is the last chunk") - - -class StreamCompletedEvent(NodeEventBase): - node_run_result: NodeRunResult = Field(..., description="run result") - - -class PauseRequestedEvent(NodeEventBase): - reason: PauseReason = Field(..., description="pause reason") - - -class HumanInputFormFilledEvent(NodeEventBase): - """Event emitted when a human input form is submitted.""" - - node_title: str - rendered_content: str - action_id: str - action_text: str - - -class HumanInputFormTimeoutEvent(NodeEventBase): - """Event emitted when a human input form times out.""" - - node_title: str - expiration_time: datetime diff --git a/api/dify_graph/nodes/__init__.py b/api/dify_graph/nodes/__init__.py deleted file mode 100644 index 0223149bb8..0000000000 --- a/api/dify_graph/nodes/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ 
-from dify_graph.enums import BuiltinNodeTypes - -__all__ = ["BuiltinNodeTypes"] diff --git a/api/dify_graph/nodes/answer/__init__.py b/api/dify_graph/nodes/answer/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/api/dify_graph/nodes/answer/answer_node.py b/api/dify_graph/nodes/answer/answer_node.py deleted file mode 100644 index 4286e1a492..0000000000 --- a/api/dify_graph/nodes/answer/answer_node.py +++ /dev/null @@ -1,70 +0,0 @@ -from collections.abc import Mapping, Sequence -from typing import Any - -from dify_graph.enums import BuiltinNodeTypes, NodeExecutionType, WorkflowNodeExecutionStatus -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.answer.entities import AnswerNodeData -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.base.template import Template -from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser -from dify_graph.variables import ArrayFileSegment, FileSegment, Segment - - -class AnswerNode(Node[AnswerNodeData]): - node_type = BuiltinNodeTypes.ANSWER - execution_type = NodeExecutionType.RESPONSE - - @classmethod - def version(cls) -> str: - return "1" - - def _run(self) -> NodeRunResult: - segments = self.graph_runtime_state.variable_pool.convert_template(self.node_data.answer) - files = self._extract_files_from_segments(segments.value) - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs={"answer": segments.markdown, "files": ArrayFileSegment(value=files)}, - ) - - def _extract_files_from_segments(self, segments: Sequence[Segment]): - """Extract all files from segments containing FileSegment or ArrayFileSegment instances. - - FileSegment contains a single file, while ArrayFileSegment contains multiple files. - This method flattens all files into a single list. 
- """ - files = [] - for segment in segments: - if isinstance(segment, FileSegment): - # Single file - wrap in list for consistency - files.append(segment.value) - elif isinstance(segment, ArrayFileSegment): - # Multiple files - extend the list - files.extend(segment.value) - return files - - @classmethod - def _extract_variable_selector_to_variable_mapping( - cls, - *, - graph_config: Mapping[str, Any], - node_id: str, - node_data: AnswerNodeData, - ) -> Mapping[str, Sequence[str]]: - _ = graph_config # Explicitly mark as unused - variable_template_parser = VariableTemplateParser(template=node_data.answer) - variable_selectors = variable_template_parser.extract_variable_selectors() - - variable_mapping = {} - for variable_selector in variable_selectors: - variable_mapping[node_id + "." + variable_selector.variable] = variable_selector.value_selector - - return variable_mapping - - def get_streaming_template(self) -> Template: - """ - Get the template for streaming. - - Returns: - Template instance for this Answer node - """ - return Template.from_answer_template(self.node_data.answer) diff --git a/api/dify_graph/nodes/answer/entities.py b/api/dify_graph/nodes/answer/entities.py deleted file mode 100644 index cd82df1ac4..0000000000 --- a/api/dify_graph/nodes/answer/entities.py +++ /dev/null @@ -1,67 +0,0 @@ -from collections.abc import Sequence -from enum import StrEnum, auto - -from pydantic import BaseModel, Field - -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, NodeType - - -class AnswerNodeData(BaseNodeData): - """ - Answer Node Data. - """ - - type: NodeType = BuiltinNodeTypes.ANSWER - answer: str = Field(..., description="answer template string") - - -class GenerateRouteChunk(BaseModel): - """ - Generate Route Chunk. 
- """ - - class ChunkType(StrEnum): - VAR = auto() - TEXT = auto() - - type: ChunkType = Field(..., description="generate route chunk type") - - -class VarGenerateRouteChunk(GenerateRouteChunk): - """ - Var Generate Route Chunk. - """ - - type: GenerateRouteChunk.ChunkType = GenerateRouteChunk.ChunkType.VAR - """generate route chunk type""" - value_selector: Sequence[str] = Field(..., description="value selector") - - -class TextGenerateRouteChunk(GenerateRouteChunk): - """ - Text Generate Route Chunk. - """ - - type: GenerateRouteChunk.ChunkType = GenerateRouteChunk.ChunkType.TEXT - """generate route chunk type""" - text: str = Field(..., description="text") - - -class AnswerNodeDoubleLink(BaseModel): - node_id: str = Field(..., description="node id") - source_node_ids: list[str] = Field(..., description="source node ids") - target_node_ids: list[str] = Field(..., description="target node ids") - - -class AnswerStreamGenerateRoute(BaseModel): - """ - AnswerStreamGenerateRoute entity - """ - - answer_dependencies: dict[str, list[str]] = Field( - ..., description="answer dependencies (answer node id -> dependent answer node ids)" - ) - answer_generate_route: dict[str, list[GenerateRouteChunk]] = Field( - ..., description="answer generate route (answer node id -> generate route chunks)" - ) diff --git a/api/dify_graph/nodes/base/__init__.py b/api/dify_graph/nodes/base/__init__.py deleted file mode 100644 index 036e25895d..0000000000 --- a/api/dify_graph/nodes/base/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -from .entities import BaseIterationNodeData, BaseIterationState, BaseLoopNodeData, BaseLoopState -from .usage_tracking_mixin import LLMUsageTrackingMixin - -__all__ = [ - "BaseIterationNodeData", - "BaseIterationState", - "BaseLoopNodeData", - "BaseLoopState", - "LLMUsageTrackingMixin", -] diff --git a/api/dify_graph/nodes/base/entities.py b/api/dify_graph/nodes/base/entities.py deleted file mode 100644 index 4f8b2682e1..0000000000 --- 
a/api/dify_graph/nodes/base/entities.py +++ /dev/null @@ -1,87 +0,0 @@ -from __future__ import annotations - -from collections.abc import Sequence -from enum import StrEnum -from typing import Any - -from pydantic import BaseModel, field_validator - -from dify_graph.entities.base_node_data import BaseNodeData - - -class VariableSelector(BaseModel): - """ - Variable Selector. - """ - - variable: str - value_selector: Sequence[str] - - -class OutputVariableType(StrEnum): - STRING = "string" - NUMBER = "number" - INTEGER = "integer" - SECRET = "secret" - BOOLEAN = "boolean" - OBJECT = "object" - FILE = "file" - ARRAY = "array" - ARRAY_STRING = "array[string]" - ARRAY_NUMBER = "array[number]" - ARRAY_OBJECT = "array[object]" - ARRAY_BOOLEAN = "array[boolean]" - ARRAY_FILE = "array[file]" - ANY = "any" - ARRAY_ANY = "array[any]" - - -class OutputVariableEntity(BaseModel): - """ - Output Variable Entity. - """ - - variable: str - value_type: OutputVariableType = OutputVariableType.ANY - value_selector: Sequence[str] - - @field_validator("value_type", mode="before") - @classmethod - def normalize_value_type(cls, v: Any) -> Any: - """ - Normalize value_type to handle case-insensitive array types. - Converts 'Array[...]' to 'array[...]' for backward compatibility. 
- """ - if isinstance(v, str) and v.startswith("Array["): - return v.lower() - return v - - -class BaseIterationNodeData(BaseNodeData): - start_node_id: str | None = None - - -class BaseIterationState(BaseModel): - iteration_node_id: str - index: int - inputs: dict - - class MetaData(BaseModel): - pass - - metadata: MetaData - - -class BaseLoopNodeData(BaseNodeData): - start_node_id: str | None = None - - -class BaseLoopState(BaseModel): - loop_node_id: str - index: int - inputs: dict - - class MetaData(BaseModel): - pass - - metadata: MetaData diff --git a/api/dify_graph/nodes/base/node.py b/api/dify_graph/nodes/base/node.py deleted file mode 100644 index 56b46a5894..0000000000 --- a/api/dify_graph/nodes/base/node.py +++ /dev/null @@ -1,808 +0,0 @@ -from __future__ import annotations - -import logging -import operator -from abc import abstractmethod -from collections.abc import Generator, Mapping, Sequence -from functools import singledispatchmethod -from types import MappingProxyType -from typing import Any, ClassVar, Generic, Protocol, TypeVar, cast, get_args, get_origin -from uuid import uuid4 - -from dify_graph.entities import GraphInitParams -from dify_graph.entities.base_node_data import BaseNodeData, RetryConfig -from dify_graph.entities.graph_config import NodeConfigDict -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY -from dify_graph.enums import ( - ErrorStrategy, - NodeExecutionType, - NodeState, - NodeType, - WorkflowNodeExecutionStatus, -) -from dify_graph.graph_events import ( - GraphNodeEventBase, - NodeRunAgentLogEvent, - NodeRunFailedEvent, - NodeRunHumanInputFormFilledEvent, - NodeRunHumanInputFormTimeoutEvent, - NodeRunIterationFailedEvent, - NodeRunIterationNextEvent, - NodeRunIterationStartedEvent, - NodeRunIterationSucceededEvent, - NodeRunLoopFailedEvent, - NodeRunLoopNextEvent, - NodeRunLoopStartedEvent, - NodeRunLoopSucceededEvent, - NodeRunPauseRequestedEvent, - NodeRunRetrieverResourceEvent, - NodeRunStartedEvent, 
- NodeRunStreamChunkEvent, - NodeRunSucceededEvent, -) -from dify_graph.node_events import ( - AgentLogEvent, - HumanInputFormFilledEvent, - HumanInputFormTimeoutEvent, - IterationFailedEvent, - IterationNextEvent, - IterationStartedEvent, - IterationSucceededEvent, - LoopFailedEvent, - LoopNextEvent, - LoopStartedEvent, - LoopSucceededEvent, - NodeEventBase, - NodeRunResult, - PauseRequestedEvent, - RunRetrieverResourceEvent, - StreamChunkEvent, - StreamCompletedEvent, -) -from dify_graph.runtime import GraphRuntimeState -from libs.datetime_utils import naive_utc_now - -NodeDataT = TypeVar("NodeDataT", bound=BaseNodeData) -_MISSING_RUN_CONTEXT_VALUE = object() - -logger = logging.getLogger(__name__) - - -class DifyRunContextProtocol(Protocol): - tenant_id: str - app_id: str - user_id: str - user_from: Any - invoke_from: Any - - -class _MappingDifyRunContext: - def __init__(self, mapping: Mapping[str, Any]) -> None: - self.tenant_id = str(mapping["tenant_id"]) - self.app_id = str(mapping["app_id"]) - self.user_id = str(mapping["user_id"]) - self.user_from = mapping["user_from"] - self.invoke_from = mapping["invoke_from"] - - -class Node(Generic[NodeDataT]): - """BaseNode serves as the foundational class for all node implementations. - - Nodes are allowed to maintain transient states (e.g., `LLMNode` uses the `_file_output` - attribute to track files generated by the LLM). However, these states are not persisted - when the workflow is suspended or resumed. If a node needs its state to be preserved - across workflow suspension and resumption, it should include the relevant state data - in its output. - """ - - node_type: ClassVar[NodeType] - execution_type: NodeExecutionType = NodeExecutionType.EXECUTABLE - _node_data_type: ClassVar[type[BaseNodeData]] = BaseNodeData - - def __init_subclass__(cls, **kwargs: Any) -> None: - """ - Automatically extract and validate the node data type from the generic parameter. 
- - When a subclass is defined as `class MyNode(Node[MyNodeData])`, this method: - 1. Inspects `__orig_bases__` to find the `Node[T]` parameterization - 2. Extracts `T` (e.g., `MyNodeData`) from the generic argument - 3. Validates that `T` is a proper `BaseNodeData` subclass - 4. Stores it in `_node_data_type` for automatic hydration in `__init__` - - This eliminates the need for subclasses to manually implement boilerplate - accessor methods like `_get_title()`, `_get_error_strategy()`, etc. - - How it works: - :: - - class CodeNode(Node[CodeNodeData]): - │ │ - │ └─────────────────────────────────┐ - │ │ - ▼ ▼ - ┌─────────────────────────────┐ ┌─────────────────────────────────┐ - │ __orig_bases__ = ( │ │ CodeNodeData(BaseNodeData) │ - │ Node[CodeNodeData], │ │ title: str │ - │ ) │ │ desc: str | None │ - └──────────────┬──────────────┘ │ ... │ - │ └─────────────────────────────────┘ - ▼ ▲ - ┌─────────────────────────────┐ │ - │ get_origin(base) -> Node │ │ - │ get_args(base) -> ( │ │ - │ CodeNodeData, │ ──────────────────────┘ - │ ) │ - └──────────────┬──────────────┘ - │ - ▼ - ┌─────────────────────────────┐ - │ Validate: │ - │ - Is it a type? │ - │ - Is it a BaseNodeData │ - │ subclass? │ - └──────────────┬──────────────┘ - │ - ▼ - ┌─────────────────────────────┐ - │ cls._node_data_type = │ - │ CodeNodeData │ - └─────────────────────────────┘ - - Later, in __init__: - :: - - config["data"] ──► _node_data_type.model_validate(..., from_attributes=True) - │ - ▼ - CodeNodeData instance - (stored in self._node_data) - - Example: - class CodeNode(Node[CodeNodeData]): # CodeNodeData is auto-extracted - node_type = BuiltinNodeTypes.CODE - # No need to implement _get_title, _get_error_strategy, etc. 
- """ - super().__init_subclass__(**kwargs) - - if cls is Node: - return - - node_data_type = cls._extract_node_data_type_from_generic() - - if node_data_type is None: - raise TypeError(f"{cls.__name__} must inherit from Node[T] with a BaseNodeData subtype") - - cls._node_data_type = node_data_type - - # Skip base class itself - if cls is Node: - return - # Only register production node implementations defined under the - # canonical workflow namespaces. - # This prevents test helper subclasses from polluting the global registry and - # accidentally overriding real node types (e.g., a test Answer node). - module_name = getattr(cls, "__module__", "") - # Only register concrete subclasses that define node_type and version() - node_type = cls.node_type - version = cls.version() - bucket = Node._registry.setdefault(node_type, {}) - if module_name.startswith(("dify_graph.nodes.", "core.workflow.nodes.")): - # Production node definitions take precedence and may override - bucket[version] = cls # type: ignore[index] - else: - # External/test subclasses may register but must not override production - bucket.setdefault(version, cls) # type: ignore[index] - # Maintain a "latest" pointer preferring numeric versions; fallback to lexicographic - version_keys = [v for v in bucket if v != "latest"] - numeric_pairs: list[tuple[str, int]] = [] - for v in version_keys: - numeric_pairs.append((v, int(v))) - if numeric_pairs: - latest_key = max(numeric_pairs, key=operator.itemgetter(1))[0] - else: - latest_key = max(version_keys) if version_keys else version - bucket["latest"] = bucket[latest_key] - Node._registry_version += 1 - - @classmethod - def _extract_node_data_type_from_generic(cls) -> type[BaseNodeData] | None: - """ - Extract the node data type from the generic parameter `Node[T]`. - - Inspects `__orig_bases__` to find the `Node[T]` parameterization and extracts `T`. - - Returns: - The extracted BaseNodeData subtype, or None if not found. 
- - Raises: - TypeError: If the generic argument is invalid (not exactly one argument, - or not a BaseNodeData subtype). - """ - # __orig_bases__ contains the original generic bases before type erasure. - # For `class CodeNode(Node[CodeNodeData])`, this would be `(Node[CodeNodeData],)`. - for base in getattr(cls, "__orig_bases__", ()): # type: ignore[attr-defined] - origin = get_origin(base) # Returns `Node` for `Node[CodeNodeData]` - if origin is Node: - args = get_args(base) # Returns `(CodeNodeData,)` for `Node[CodeNodeData]` - if len(args) != 1: - raise TypeError(f"{cls.__name__} must specify exactly one node data generic argument") - - candidate = args[0] - if not isinstance(candidate, type) or not issubclass(candidate, BaseNodeData): - raise TypeError(f"{cls.__name__} must parameterize Node with a BaseNodeData subtype") - - return candidate - - return None - - # Global registry populated via __init_subclass__ - _registry: ClassVar[dict[NodeType, dict[str, type[Node]]]] = {} - _registry_version: ClassVar[int] = 0 - - @classmethod - def get_registry_version(cls) -> int: - return cls._registry_version - - def __init__( - self, - id: str, - config: NodeConfigDict, - graph_init_params: GraphInitParams, - graph_runtime_state: GraphRuntimeState, - ) -> None: - self._graph_init_params = graph_init_params - self._run_context = MappingProxyType(dict(graph_init_params.run_context)) - self.id = id - self.workflow_id = graph_init_params.workflow_id - self.graph_config = graph_init_params.graph_config - self.workflow_call_depth = graph_init_params.call_depth - self.graph_runtime_state = graph_runtime_state - self.state: NodeState = NodeState.UNKNOWN # node execution state - - node_id = config["id"] - - self._node_id = node_id - self._node_execution_id: str = "" - self._start_at = naive_utc_now() - - self._node_data = self.validate_node_data(config["data"]) - - self.post_init() - - @classmethod - def validate_node_data(cls, node_data: BaseNodeData) -> NodeDataT: - 
"""Validate shared graph node payloads against the subclass-declared NodeData model.""" - return cast(NodeDataT, cls._node_data_type.model_validate(node_data, from_attributes=True)) - - def init_node_data(self, data: BaseNodeData | Mapping[str, Any]) -> None: - """Hydrate `_node_data` for legacy callers that bypass `__init__`.""" - self._node_data = self.validate_node_data(cast(BaseNodeData, data)) - - def post_init(self) -> None: - """Optional hook for subclasses requiring extra initialization.""" - return - - @property - def graph_init_params(self) -> GraphInitParams: - return self._graph_init_params - - @property - def run_context(self) -> Mapping[str, Any]: - return self._run_context - - def get_run_context_value(self, key: str, default: Any = None) -> Any: - return self._run_context.get(key, default) - - def require_run_context_value(self, key: str) -> Any: - value = self.get_run_context_value(key, _MISSING_RUN_CONTEXT_VALUE) - if value is _MISSING_RUN_CONTEXT_VALUE: - raise ValueError(f"run_context missing required key: {key}") - return value - - def require_dify_context(self) -> DifyRunContextProtocol: - raw_ctx = self.require_run_context_value(DIFY_RUN_CONTEXT_KEY) - if raw_ctx is None: - raise ValueError(f"run_context missing required key: {DIFY_RUN_CONTEXT_KEY}") - - if isinstance(raw_ctx, Mapping): - missing_keys = [ - key for key in ("tenant_id", "app_id", "user_id", "user_from", "invoke_from") if key not in raw_ctx - ] - if missing_keys: - raise ValueError(f"dify context missing required keys: {', '.join(missing_keys)}") - return _MappingDifyRunContext(raw_ctx) - - for attr in ("tenant_id", "app_id", "user_id", "user_from", "invoke_from"): - if not hasattr(raw_ctx, attr): - raise TypeError(f"invalid dify context object, missing attribute: {attr}") - - return cast(DifyRunContextProtocol, raw_ctx) - - @property - def execution_id(self) -> str: - return self._node_execution_id - - def ensure_execution_id(self) -> str: - if self._node_execution_id: - 
return self._node_execution_id - - resumed_execution_id = self._restore_execution_id_from_runtime_state() - if resumed_execution_id: - self._node_execution_id = resumed_execution_id - return self._node_execution_id - - self._node_execution_id = str(uuid4()) - return self._node_execution_id - - def _restore_execution_id_from_runtime_state(self) -> str | None: - graph_execution = self.graph_runtime_state.graph_execution - try: - node_executions = graph_execution.node_executions - except AttributeError: - return None - if not isinstance(node_executions, dict): - return None - node_execution = node_executions.get(self._node_id) - if node_execution is None: - return None - execution_id = node_execution.execution_id - if not execution_id: - return None - return str(execution_id) - - @abstractmethod - def _run(self) -> NodeRunResult | Generator[NodeEventBase, None, None]: - """ - Run node - :return: - """ - raise NotImplementedError - - def populate_start_event(self, event: NodeRunStartedEvent) -> None: - """Allow subclasses to enrich the started event without cross-node imports in the base class.""" - _ = event - - def run(self) -> Generator[GraphNodeEventBase, None, None]: - execution_id = self.ensure_execution_id() - self._start_at = naive_utc_now() - - # Create and push start event with required fields - start_event = NodeRunStartedEvent( - id=execution_id, - node_id=self._node_id, - node_type=self.node_type, - node_title=self.title, - in_iteration_id=None, - start_at=self._start_at, - ) - try: - self.populate_start_event(start_event) - except Exception: - logger.warning("Failed to populate start event for node %s", self._node_id, exc_info=True) - yield start_event - - try: - result = self._run() - - # Handle NodeRunResult - if isinstance(result, NodeRunResult): - yield self._convert_node_run_result_to_graph_node_event(result) - return - - # Handle event stream - for event in result: - # NOTE: this is necessary because iteration and loop nodes yield GraphNodeEventBase 
- if isinstance(event, NodeEventBase): # pyright: ignore[reportUnnecessaryIsInstance] - yield self._dispatch(event) - elif isinstance(event, GraphNodeEventBase) and not event.in_iteration_id and not event.in_loop_id: # pyright: ignore[reportUnnecessaryIsInstance] - event.id = self.execution_id - yield event - else: - yield event - except Exception as e: - logger.exception("Node %s failed to run", self._node_id) - result = NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error=str(e), - error_type="WorkflowNodeError", - ) - finished_at = naive_utc_now() - yield NodeRunFailedEvent( - id=self.execution_id, - node_id=self._node_id, - node_type=self.node_type, - start_at=self._start_at, - finished_at=finished_at, - node_run_result=result, - error=str(e), - ) - - @classmethod - def extract_variable_selector_to_variable_mapping( - cls, - *, - graph_config: Mapping[str, Any], - config: NodeConfigDict, - ) -> Mapping[str, Sequence[str]]: - """Extracts references variable selectors from node configuration. - - The `config` parameter represents the configuration for a specific node type and corresponds - to the `data` field in the node definition object. - - The returned mapping has the following structure: - - {'1747829548239.#1747829667553.result#': ['1747829667553', 'result']} - - For loop and iteration nodes, the mapping may look like this: - - { - "1748332301644.input_selector": ["1748332363630", "result"], - "1748332325079.1748332325079.#sys.workflow_id#": ["sys", "workflow_id"], - } - - where `1748332301644` is the ID of the loop / iteration node, - and `1748332325079` is the ID of the node inside the loop or iteration node. - - Here, the key consists of two parts: the current node ID (provided as the `node_id` - parameter to `_extract_variable_selector_to_variable_mapping`) and the variable selector, - enclosed in `#` symbols. These two parts are separated by a dot (`.`). 
- - The value is a list of string representing the variable selector, where the first element is the node ID - of the referenced variable, and the second element is the variable name within that node. - - The meaning of the above response is: - - The node with ID `1747829548239` references the variable `result` from the node with - ID `1747829667553`. For example, if `1747829548239` is a LLM node, its prompt may contain a - reference to the `result` output variable of node `1747829667553`. - - :param graph_config: graph config - :param config: node config - :return: - """ - node_id = config["id"] - node_data = cls.validate_node_data(config["data"]) - data = cls._extract_variable_selector_to_variable_mapping( - graph_config=graph_config, - node_id=node_id, - node_data=node_data, - ) - return data - - @classmethod - def _extract_variable_selector_to_variable_mapping( - cls, - *, - graph_config: Mapping[str, Any], - node_id: str, - node_data: NodeDataT, - ) -> Mapping[str, Sequence[str]]: - return {} - - def blocks_variable_output(self, variable_selectors: set[tuple[str, ...]]) -> bool: - """ - Check if this node blocks the output of specific variables. - - This method is used to determine if a node must complete execution before - the specified variables can be used in streaming output. - - :param variable_selectors: Set of variable selectors, each as a tuple (e.g., ('conversation', 'str')) - :return: True if this node blocks output of any of the specified variables, False otherwise - """ - return False - - @classmethod - def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]: - return {} - - @classmethod - @abstractmethod - def version(cls) -> str: - """`node_version` returns the version of current node type.""" - # NOTE(QuantumGhost): Node versions must remain unique per `NodeType` so - # registry lookups can resolve numeric versions and `latest`. 
- raise NotImplementedError("subclasses of BaseNode must implement `version` method.") - - @classmethod - def get_node_type_classes_mapping(cls) -> Mapping[NodeType, Mapping[str, type[Node]]]: - """Return a read-only view of the currently registered node classes. - - This accessor intentionally performs no imports. The embedding layer that - owns bootstrap (for example `core.workflow.node_factory`) must import any - extension node packages before calling it so their subclasses register via - `__init_subclass__`. - """ - return {node_type: MappingProxyType(version_map) for node_type, version_map in cls._registry.items()} - - @property - def retry(self) -> bool: - return False - - def _get_error_strategy(self) -> ErrorStrategy | None: - """Get the error strategy for this node.""" - return self._node_data.error_strategy - - def _get_retry_config(self) -> RetryConfig: - """Get the retry configuration for this node.""" - return self._node_data.retry_config - - def _get_title(self) -> str: - """Get the node title.""" - return self._node_data.title - - def _get_description(self) -> str | None: - """Get the node description.""" - return self._node_data.desc - - def _get_default_value_dict(self) -> dict[str, Any]: - """Get the default values dictionary for this node.""" - return self._node_data.default_value_dict - - # Public interface properties that delegate to abstract methods - @property - def error_strategy(self) -> ErrorStrategy | None: - """Get the error strategy for this node.""" - return self._get_error_strategy() - - @property - def retry_config(self) -> RetryConfig: - """Get the retry configuration for this node.""" - return self._get_retry_config() - - @property - def title(self) -> str: - """Get the node title.""" - return self._get_title() - - @property - def description(self) -> str | None: - """Get the node description.""" - return self._get_description() - - @property - def default_value_dict(self) -> dict[str, Any]: - """Get the default values dictionary 
for this node.""" - return self._get_default_value_dict() - - @property - def node_data(self) -> NodeDataT: - """Typed access to this node's configuration data.""" - return self._node_data - - def _convert_node_run_result_to_graph_node_event(self, result: NodeRunResult) -> GraphNodeEventBase: - finished_at = naive_utc_now() - match result.status: - case WorkflowNodeExecutionStatus.FAILED: - return NodeRunFailedEvent( - id=self.execution_id, - node_id=self.id, - node_type=self.node_type, - start_at=self._start_at, - finished_at=finished_at, - node_run_result=result, - error=result.error, - ) - case WorkflowNodeExecutionStatus.SUCCEEDED: - return NodeRunSucceededEvent( - id=self.execution_id, - node_id=self.id, - node_type=self.node_type, - start_at=self._start_at, - finished_at=finished_at, - node_run_result=result, - ) - case _: - raise Exception(f"result status {result.status} not supported") - - @singledispatchmethod - def _dispatch(self, event: NodeEventBase) -> GraphNodeEventBase: - raise NotImplementedError(f"Node {self._node_id} does not support event type {type(event)}") - - @_dispatch.register - def _(self, event: StreamChunkEvent) -> NodeRunStreamChunkEvent: - return NodeRunStreamChunkEvent( - id=self.execution_id, - node_id=self._node_id, - node_type=self.node_type, - selector=event.selector, - chunk=event.chunk, - is_final=event.is_final, - ) - - @_dispatch.register - def _(self, event: StreamCompletedEvent) -> NodeRunSucceededEvent | NodeRunFailedEvent: - finished_at = naive_utc_now() - match event.node_run_result.status: - case WorkflowNodeExecutionStatus.SUCCEEDED: - return NodeRunSucceededEvent( - id=self.execution_id, - node_id=self._node_id, - node_type=self.node_type, - start_at=self._start_at, - finished_at=finished_at, - node_run_result=event.node_run_result, - ) - case WorkflowNodeExecutionStatus.FAILED: - return NodeRunFailedEvent( - id=self.execution_id, - node_id=self._node_id, - node_type=self.node_type, - start_at=self._start_at, - 
finished_at=finished_at, - node_run_result=event.node_run_result, - error=event.node_run_result.error, - ) - case _: - raise NotImplementedError( - f"Node {self._node_id} does not support status {event.node_run_result.status}" - ) - - @_dispatch.register - def _(self, event: PauseRequestedEvent) -> NodeRunPauseRequestedEvent: - return NodeRunPauseRequestedEvent( - id=self.execution_id, - node_id=self._node_id, - node_type=self.node_type, - node_run_result=NodeRunResult(status=WorkflowNodeExecutionStatus.PAUSED), - reason=event.reason, - ) - - @_dispatch.register - def _(self, event: AgentLogEvent) -> NodeRunAgentLogEvent: - return NodeRunAgentLogEvent( - id=self.execution_id, - node_id=self._node_id, - node_type=self.node_type, - message_id=event.message_id, - label=event.label, - node_execution_id=event.node_execution_id, - parent_id=event.parent_id, - error=event.error, - status=event.status, - data=event.data, - metadata=event.metadata, - ) - - @_dispatch.register - def _(self, event: HumanInputFormFilledEvent): - return NodeRunHumanInputFormFilledEvent( - id=self.execution_id, - node_id=self._node_id, - node_type=self.node_type, - node_title=event.node_title, - rendered_content=event.rendered_content, - action_id=event.action_id, - action_text=event.action_text, - ) - - @_dispatch.register - def _(self, event: HumanInputFormTimeoutEvent): - return NodeRunHumanInputFormTimeoutEvent( - id=self.execution_id, - node_id=self._node_id, - node_type=self.node_type, - node_title=event.node_title, - expiration_time=event.expiration_time, - ) - - @_dispatch.register - def _(self, event: LoopStartedEvent) -> NodeRunLoopStartedEvent: - return NodeRunLoopStartedEvent( - id=self.execution_id, - node_id=self._node_id, - node_type=self.node_type, - node_title=self.node_data.title, - start_at=event.start_at, - inputs=event.inputs, - metadata=event.metadata, - predecessor_node_id=event.predecessor_node_id, - ) - - @_dispatch.register - def _(self, event: LoopNextEvent) -> 
NodeRunLoopNextEvent: - return NodeRunLoopNextEvent( - id=self.execution_id, - node_id=self._node_id, - node_type=self.node_type, - node_title=self.node_data.title, - index=event.index, - pre_loop_output=event.pre_loop_output, - ) - - @_dispatch.register - def _(self, event: LoopSucceededEvent) -> NodeRunLoopSucceededEvent: - return NodeRunLoopSucceededEvent( - id=self.execution_id, - node_id=self._node_id, - node_type=self.node_type, - node_title=self.node_data.title, - start_at=event.start_at, - inputs=event.inputs, - outputs=event.outputs, - metadata=event.metadata, - steps=event.steps, - ) - - @_dispatch.register - def _(self, event: LoopFailedEvent) -> NodeRunLoopFailedEvent: - return NodeRunLoopFailedEvent( - id=self.execution_id, - node_id=self._node_id, - node_type=self.node_type, - node_title=self.node_data.title, - start_at=event.start_at, - inputs=event.inputs, - outputs=event.outputs, - metadata=event.metadata, - steps=event.steps, - error=event.error, - ) - - @_dispatch.register - def _(self, event: IterationStartedEvent) -> NodeRunIterationStartedEvent: - return NodeRunIterationStartedEvent( - id=self.execution_id, - node_id=self._node_id, - node_type=self.node_type, - node_title=self.node_data.title, - start_at=event.start_at, - inputs=event.inputs, - metadata=event.metadata, - predecessor_node_id=event.predecessor_node_id, - ) - - @_dispatch.register - def _(self, event: IterationNextEvent) -> NodeRunIterationNextEvent: - return NodeRunIterationNextEvent( - id=self.execution_id, - node_id=self._node_id, - node_type=self.node_type, - node_title=self.node_data.title, - index=event.index, - pre_iteration_output=event.pre_iteration_output, - ) - - @_dispatch.register - def _(self, event: IterationSucceededEvent) -> NodeRunIterationSucceededEvent: - return NodeRunIterationSucceededEvent( - id=self.execution_id, - node_id=self._node_id, - node_type=self.node_type, - node_title=self.node_data.title, - start_at=event.start_at, - inputs=event.inputs, - 
outputs=event.outputs, - metadata=event.metadata, - steps=event.steps, - ) - - @_dispatch.register - def _(self, event: IterationFailedEvent) -> NodeRunIterationFailedEvent: - return NodeRunIterationFailedEvent( - id=self.execution_id, - node_id=self._node_id, - node_type=self.node_type, - node_title=self.node_data.title, - start_at=event.start_at, - inputs=event.inputs, - outputs=event.outputs, - metadata=event.metadata, - steps=event.steps, - error=event.error, - ) - - @_dispatch.register - def _(self, event: RunRetrieverResourceEvent) -> NodeRunRetrieverResourceEvent: - from core.rag.entities.citation_metadata import RetrievalSourceMetadata - - retriever_resources = [ - RetrievalSourceMetadata.model_validate(resource) for resource in event.retriever_resources - ] - return NodeRunRetrieverResourceEvent( - id=self.execution_id, - node_id=self._node_id, - node_type=self.node_type, - retriever_resources=retriever_resources, - context=event.context, - node_version=self.version(), - ) diff --git a/api/dify_graph/nodes/base/template.py b/api/dify_graph/nodes/base/template.py deleted file mode 100644 index 5976e808e3..0000000000 --- a/api/dify_graph/nodes/base/template.py +++ /dev/null @@ -1,150 +0,0 @@ -"""Template structures for Response nodes (Answer and End). - -This module provides a unified template structure for both Answer and End nodes, -similar to SegmentGroup but focused on template representation without values. 
-""" - -from __future__ import annotations - -from abc import ABC, abstractmethod -from collections.abc import Sequence -from dataclasses import dataclass -from typing import Any, Union - -from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser - - -@dataclass(frozen=True) -class TemplateSegment(ABC): - """Base class for template segments.""" - - @abstractmethod - def __str__(self) -> str: - """String representation of the segment.""" - pass - - -@dataclass(frozen=True) -class TextSegment(TemplateSegment): - """A text segment in a template.""" - - text: str - - def __str__(self) -> str: - return self.text - - -@dataclass(frozen=True) -class VariableSegment(TemplateSegment): - """A variable reference segment in a template.""" - - selector: Sequence[str] - variable_name: str | None = None # Optional variable name for End nodes - - def __str__(self) -> str: - return "{{#" + ".".join(self.selector) + "#}}" - - -# Type alias for segments -TemplateSegmentUnion = Union[TextSegment, VariableSegment] - - -@dataclass(frozen=True) -class Template: - """Unified template structure for Response nodes. - - Similar to SegmentGroup, but represents the template structure - without variable values - only marking variable selectors. - """ - - segments: list[TemplateSegmentUnion] - - @classmethod - def from_answer_template(cls, template_str: str) -> Template: - """Create a Template from an Answer node template string. 
- - Example: - "Hello, {{#node1.name#}}" -> [TextSegment("Hello, "), VariableSegment(["node1", "name"])] - - Args: - template_str: The answer template string - - Returns: - Template instance - """ - parser = VariableTemplateParser(template_str) - segments: list[TemplateSegmentUnion] = [] - - # Extract variable selectors to find all variables - variable_selectors = parser.extract_variable_selectors() - var_map = {var.variable: var.value_selector for var in variable_selectors} - - # Parse template to get ordered segments - # We need to split the template by variable placeholders while preserving order - import re - - # Create a regex pattern that matches variable placeholders - pattern = r"\{\{(#[a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10}#)\}\}" - - # Split template while keeping the delimiters (variable placeholders) - parts = re.split(pattern, template_str) - - for i, part in enumerate(parts): - if not part: - continue - - # Check if this part is a variable reference (odd indices after split) - if i % 2 == 1: # Odd indices are variable keys - # Remove the # symbols from the variable key - var_key = part - if var_key in var_map: - segments.append(VariableSegment(selector=list(var_map[var_key]))) - else: - # This shouldn't happen with valid templates - segments.append(TextSegment(text="{{" + part + "}}")) - else: - # Even indices are text segments - segments.append(TextSegment(text=part)) - - return cls(segments=segments) - - @classmethod - def from_end_outputs(cls, outputs_config: list[dict[str, Any]]) -> Template: - """Create a Template from an End node outputs configuration. - - End nodes are treated as templates of concatenated variables with newlines. 
- - Example: - [{"variable": "text", "value_selector": ["node1", "text"]}, - {"variable": "result", "value_selector": ["node2", "result"]}] - -> - [VariableSegment(["node1", "text"]), - TextSegment("\n"), - VariableSegment(["node2", "result"])] - - Args: - outputs_config: List of output configurations with variable and value_selector - - Returns: - Template instance - """ - segments: list[TemplateSegmentUnion] = [] - - for i, output in enumerate(outputs_config): - if i > 0: - # Add newline separator between variables - segments.append(TextSegment(text="\n")) - - value_selector = output.get("value_selector", []) - variable_name = output.get("variable", "") - if value_selector: - segments.append(VariableSegment(selector=list(value_selector), variable_name=variable_name)) - - if len(segments) > 0 and isinstance(segments[-1], TextSegment): - segments = segments[:-1] - - return cls(segments=segments) - - def __str__(self) -> str: - """String representation of the template.""" - return "".join(str(segment) for segment in self.segments) diff --git a/api/dify_graph/nodes/base/usage_tracking_mixin.py b/api/dify_graph/nodes/base/usage_tracking_mixin.py deleted file mode 100644 index bd49419fd3..0000000000 --- a/api/dify_graph/nodes/base/usage_tracking_mixin.py +++ /dev/null @@ -1,28 +0,0 @@ -from dify_graph.model_runtime.entities.llm_entities import LLMUsage -from dify_graph.runtime import GraphRuntimeState - - -class LLMUsageTrackingMixin: - """Provides shared helpers for merging and recording LLM usage within workflow nodes.""" - - graph_runtime_state: GraphRuntimeState - - @staticmethod - def _merge_usage(current: LLMUsage, new_usage: LLMUsage | None) -> LLMUsage: - """Return a combined usage snapshot, preserving zero-value inputs.""" - if new_usage is None or new_usage.total_tokens <= 0: - return current - if current.total_tokens == 0: - return new_usage - return current.plus(new_usage) - - def _accumulate_usage(self, usage: LLMUsage) -> None: - """Push usage into the 
graph runtime accumulator for downstream reporting.""" - if usage.total_tokens <= 0: - return - - current_usage = self.graph_runtime_state.llm_usage - if current_usage.total_tokens == 0: - self.graph_runtime_state.llm_usage = usage.model_copy() - else: - self.graph_runtime_state.llm_usage = current_usage.plus(usage) diff --git a/api/dify_graph/nodes/base/variable_template_parser.py b/api/dify_graph/nodes/base/variable_template_parser.py deleted file mode 100644 index de5e619e8c..0000000000 --- a/api/dify_graph/nodes/base/variable_template_parser.py +++ /dev/null @@ -1,130 +0,0 @@ -import re -from collections.abc import Mapping, Sequence -from typing import Any - -from .entities import VariableSelector - -REGEX = re.compile(r"\{\{(#[a-zA-Z0-9_]{1,50}(\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10}#)\}\}") - -SELECTOR_PATTERN = re.compile(r"\{\{(#[a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10}#)\}\}") - - -def extract_selectors_from_template(template: str, /) -> Sequence[VariableSelector]: - parts = SELECTOR_PATTERN.split(template) - selectors = [] - for part in filter(lambda x: x, parts): - if "." in part and part[0] == "#" and part[-1] == "#": - selectors.append(VariableSelector(variable=f"{part}", value_selector=part[1:-1].split("."))) - return selectors - - -class VariableTemplateParser: - """ - !NOTE: Consider to use the new `segments` module instead of this class. - - A class for parsing and manipulating template variables in a string. - - Rules: - - 1. Template variables must be enclosed in `{{}}`. - 2. The template variable Key can only be: #node_id.var1.var2#. - 3. The template variable Key cannot contain new lines or spaces, and must comply with rule 2. - - Example usage: - - template = "Hello, {{#node_id.query.name#}}! Your age is {{#node_id.query.age#}}." 
- parser = VariableTemplateParser(template) - - # Extract template variable keys - variable_keys = parser.extract() - print(variable_keys) - # Output: ['#node_id.query.name#', '#node_id.query.age#'] - - # Extract variable selectors - variable_selectors = parser.extract_variable_selectors() - print(variable_selectors) - # Output: [VariableSelector(variable='#node_id.query.name#', value_selector=['node_id', 'query', 'name']), - # VariableSelector(variable='#node_id.query.age#', value_selector=['node_id', 'query', 'age'])] - - # Format the template string - inputs = {'#node_id.query.name#': 'John', '#node_id.query.age#': 25}} - formatted_string = parser.format(inputs) - print(formatted_string) - # Output: "Hello, John! Your age is 25." - """ - - def __init__(self, template: str): - self.template = template - self.variable_keys = self.extract() - - def extract(self): - """ - Extracts all the template variable keys from the template string. - - Returns: - A list of template variable keys. - """ - # Regular expression to match the template rules - matches = re.findall(REGEX, self.template) - - first_group_matches = [match[0] for match in matches] - - return list(set(first_group_matches)) - - def extract_variable_selectors(self) -> list[VariableSelector]: - """ - Extracts the variable selectors from the template variable keys. - - Returns: - A list of VariableSelector objects representing the variable selectors. - """ - variable_selectors = [] - for variable_key in self.variable_keys: - remove_hash = variable_key.replace("#", "") - split_result = remove_hash.split(".") - if len(split_result) < 2: - continue - - variable_selectors.append(VariableSelector(variable=variable_key, value_selector=split_result)) - - return variable_selectors - - def format(self, inputs: Mapping[str, Any]) -> str: - """ - Formats the template string by replacing the template variables with their corresponding values. 
- - Args: - inputs: A dictionary containing the values for the template variables. - - Returns: - The formatted string with template variables replaced by their values. - """ - - def replacer(match): - key = match.group(1) - value = inputs.get(key, match.group(0)) # return original matched string if key not found - - if value is None: - value = "" - # convert the value to string - if isinstance(value, list | dict | bool | int | float): - value = str(value) - - # remove template variables if required - return VariableTemplateParser.remove_template_variables(value) - - prompt = re.sub(REGEX, replacer, self.template) - return re.sub(r"<\|.*?\|>", "", prompt) - - @classmethod - def remove_template_variables(cls, text: str): - """ - Removes the template variables from the given text. - - Args: - text: The text from which to remove the template variables. - - Returns: - The text with template variables removed. - """ - return re.sub(REGEX, r"{\1}", text) diff --git a/api/dify_graph/nodes/code/__init__.py b/api/dify_graph/nodes/code/__init__.py deleted file mode 100644 index 8c6dcc7fcc..0000000000 --- a/api/dify_graph/nodes/code/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .code_node import CodeNode - -__all__ = ["CodeNode"] diff --git a/api/dify_graph/nodes/code/code_node.py b/api/dify_graph/nodes/code/code_node.py deleted file mode 100644 index 82d5fced62..0000000000 --- a/api/dify_graph/nodes/code/code_node.py +++ /dev/null @@ -1,493 +0,0 @@ -from collections.abc import Mapping, Sequence -from decimal import Decimal -from textwrap import dedent -from typing import TYPE_CHECKING, Any, Protocol, cast - -from dify_graph.entities.graph_config import NodeConfigDict -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.code.entities import CodeLanguage, CodeNodeData -from dify_graph.nodes.code.limits import CodeNodeLimits -from 
dify_graph.variables.segments import ArrayFileSegment -from dify_graph.variables.types import SegmentType - -from .exc import ( - CodeNodeError, - DepthLimitError, - OutputValidationError, -) - -if TYPE_CHECKING: - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState - - -class WorkflowCodeExecutor(Protocol): - def execute( - self, - *, - language: CodeLanguage, - code: str, - inputs: Mapping[str, Any], - ) -> Mapping[str, Any]: ... - - def is_execution_error(self, error: Exception) -> bool: ... - - -def _build_default_config(*, language: CodeLanguage, code: str) -> Mapping[str, object]: - return { - "type": "code", - "config": { - "variables": [ - {"variable": "arg1", "value_selector": []}, - {"variable": "arg2", "value_selector": []}, - ], - "code_language": language, - "code": code, - "outputs": {"result": {"type": "string", "children": None}}, - }, - } - - -_DEFAULT_CODE_BY_LANGUAGE: Mapping[CodeLanguage, str] = { - CodeLanguage.PYTHON3: dedent( - """ - def main(arg1: str, arg2: str): - return { - "result": arg1 + arg2, - } - """ - ), - CodeLanguage.JAVASCRIPT: dedent( - """ - function main({arg1, arg2}) { - return { - result: arg1 + arg2 - } - } - """ - ), -} - - -class CodeNode(Node[CodeNodeData]): - node_type = BuiltinNodeTypes.CODE - _limits: CodeNodeLimits - - def __init__( - self, - id: str, - config: NodeConfigDict, - graph_init_params: "GraphInitParams", - graph_runtime_state: "GraphRuntimeState", - *, - code_executor: WorkflowCodeExecutor, - code_limits: CodeNodeLimits, - ) -> None: - super().__init__( - id=id, - config=config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - self._code_executor: WorkflowCodeExecutor = code_executor - self._limits = code_limits - - @classmethod - def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]: - """ - Get default config of node. - :param filters: filter by node config parameters. 
- :return: - """ - code_language = CodeLanguage.PYTHON3 - if filters: - code_language = cast(CodeLanguage, filters.get("code_language", CodeLanguage.PYTHON3)) - - default_code = _DEFAULT_CODE_BY_LANGUAGE.get(code_language) - if default_code is None: - raise CodeNodeError(f"Unsupported code language: {code_language}") - return _build_default_config(language=code_language, code=default_code) - - @classmethod - def version(cls) -> str: - return "1" - - def _run(self) -> NodeRunResult: - # Get code language - code_language = self.node_data.code_language - code = self.node_data.code - - # Get variables - variables = {} - for variable_selector in self.node_data.variables: - variable_name = variable_selector.variable - variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector) - if isinstance(variable, ArrayFileSegment): - variables[variable_name] = [v.to_dict() for v in variable.value] if variable.value else None - else: - variables[variable_name] = variable.to_object() if variable else None - # Run code - try: - result = self._code_executor.execute( - language=code_language, - code=code, - inputs=variables, - ) - - # Transform result - result = self._transform_result(result=result, output_schema=self.node_data.outputs) - except CodeNodeError as e: - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e), error_type=type(e).__name__ - ) - except Exception as e: - if not self._code_executor.is_execution_error(e): - raise - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, inputs=variables, error=str(e), error_type=type(e).__name__ - ) - - return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs=result) - - def _check_string(self, value: str | None, variable: str) -> str | None: - """ - Check string - :param value: value - :param variable: variable - :return: - """ - if value is None: - return None - - if len(value) > self._limits.max_string_length: 
- raise OutputValidationError( - f"The length of output variable `{variable}` must be" - f" less than {self._limits.max_string_length} characters" - ) - - return value.replace("\x00", "") - - def _check_boolean(self, value: bool | None, variable: str) -> bool | None: - if value is None: - return None - - return value - - def _check_number(self, value: int | float | None, variable: str) -> int | float | None: - """ - Check number - :param value: value - :param variable: variable - :return: - """ - if value is None: - return None - - if value > self._limits.max_number or value < self._limits.min_number: - raise OutputValidationError( - f"Output variable `{variable}` is out of range," - f" it must be between {self._limits.min_number} and {self._limits.max_number}." - ) - - if isinstance(value, float): - decimal_value = Decimal(str(value)).normalize() - precision = -decimal_value.as_tuple().exponent if decimal_value.as_tuple().exponent < 0 else 0 # type: ignore[operator] - # raise error if precision is too high - if precision > self._limits.max_precision: - raise OutputValidationError( - f"Output variable `{variable}` has too high precision," - f" it must be less than {self._limits.max_precision} digits." - ) - - return value - - def _transform_result( - self, - result: Mapping[str, Any], - output_schema: dict[str, CodeNodeData.Output] | None, - prefix: str = "", - depth: int = 1, - ): - # TODO(QuantumGhost): Replace native Python lists with `Array*Segment` classes. - # Note that `_transform_result` may produce lists containing `None` values, - # which don't conform to the type requirements of `Array*Segment` classes. 
- if depth > self._limits.max_depth: - raise DepthLimitError(f"Depth limit {self._limits.max_depth} reached, object too deep.") - - transformed_result: dict[str, Any] = {} - if output_schema is None: - # validate output thought instance type - for output_name, output_value in result.items(): - if isinstance(output_value, dict): - self._transform_result( - result=output_value, - output_schema=None, - prefix=f"{prefix}.{output_name}" if prefix else output_name, - depth=depth + 1, - ) - elif isinstance(output_value, bool): - self._check_boolean(output_value, variable=f"{prefix}.{output_name}" if prefix else output_name) - elif isinstance(output_value, int | float): - self._check_number( - value=output_value, variable=f"{prefix}.{output_name}" if prefix else output_name - ) - elif isinstance(output_value, str): - self._check_string( - value=output_value, variable=f"{prefix}.{output_name}" if prefix else output_name - ) - elif isinstance(output_value, list): - first_element = output_value[0] if len(output_value) > 0 else None - if first_element is not None: - if isinstance(first_element, int | float) and all( - value is None or isinstance(value, int | float) for value in output_value - ): - for i, value in enumerate(output_value): - self._check_number( - value=value, - variable=f"{prefix}.{output_name}[{i}]" if prefix else f"{output_name}[{i}]", - ) - elif isinstance(first_element, str) and all( - value is None or isinstance(value, str) for value in output_value - ): - for i, value in enumerate(output_value): - self._check_string( - value=value, - variable=f"{prefix}.{output_name}[{i}]" if prefix else f"{output_name}[{i}]", - ) - elif ( - isinstance(first_element, dict) - and all(value is None or isinstance(value, dict) for value in output_value) - or isinstance(first_element, list) - and all(value is None or isinstance(value, list) for value in output_value) - ): - for i, value in enumerate(output_value): - if value is not None: - self._transform_result( - 
result=value, - output_schema=None, - prefix=f"{prefix}.{output_name}[{i}]" if prefix else f"{output_name}[{i}]", - depth=depth + 1, - ) - else: - raise OutputValidationError( - f"Output {prefix}.{output_name} is not a valid array." - f" make sure all elements are of the same type." - ) - elif output_value is None: - pass - else: - raise OutputValidationError(f"Output {prefix}.{output_name} is not a valid type.") - - return result - - parameters_validated = {} - for output_name, output_config in output_schema.items(): - dot = "." if prefix else "" - if output_name not in result: - raise OutputValidationError(f"Output {prefix}{dot}{output_name} is missing.") - - if output_config.type == SegmentType.OBJECT: - # check if output is object - if not isinstance(result.get(output_name), dict): - if result[output_name] is None: - transformed_result[output_name] = None - else: - raise OutputValidationError( - f"Output {prefix}{dot}{output_name} is not an object," - f" got {type(result.get(output_name))} instead." - ) - else: - transformed_result[output_name] = self._transform_result( - result=result[output_name], - output_schema=output_config.children, - prefix=f"{prefix}.{output_name}", - depth=depth + 1, - ) - elif output_config.type == SegmentType.NUMBER: - # check if number available - value = result.get(output_name) - if value is not None and not isinstance(value, (int, float)): - raise OutputValidationError( - f"Output {prefix}{dot}{output_name} is not a number," - f" got {type(result.get(output_name))} instead." - ) - checked = self._check_number(value=value, variable=f"{prefix}{dot}{output_name}") - # If the output is a boolean and the output schema specifies a NUMBER type, - # convert the boolean value to an integer. - # - # This ensures compatibility with existing workflows that may use - # `True` and `False` as values for NUMBER type outputs. 
- transformed_result[output_name] = self._convert_boolean_to_int(checked) - - elif output_config.type == SegmentType.STRING: - # check if string available - value = result.get(output_name) - if value is not None and not isinstance(value, str): - raise OutputValidationError( - f"Output {prefix}{dot}{output_name} must be a string, got {type(value).__name__} instead" - ) - transformed_result[output_name] = self._check_string( - value=value, - variable=f"{prefix}{dot}{output_name}", - ) - elif output_config.type == SegmentType.BOOLEAN: - transformed_result[output_name] = self._check_boolean( - value=result[output_name], - variable=f"{prefix}{dot}{output_name}", - ) - elif output_config.type == SegmentType.ARRAY_NUMBER: - # check if array of number available - value = result[output_name] - if not isinstance(value, list): - if value is None: - transformed_result[output_name] = None - else: - raise OutputValidationError( - f"Output {prefix}{dot}{output_name} is not an array, got {type(value)} instead." - ) - else: - if len(value) > self._limits.max_number_array_length: - raise OutputValidationError( - f"The length of output variable `{prefix}{dot}{output_name}` must be" - f" less than {self._limits.max_number_array_length} elements." - ) - - for i, inner_value in enumerate(value): - if not isinstance(inner_value, (int, float)): - raise OutputValidationError( - f"The element at index {i} of output variable `{prefix}{dot}{output_name}` must be" - f" a number." - ) - _ = self._check_number(value=inner_value, variable=f"{prefix}{dot}{output_name}[{i}]") - transformed_result[output_name] = [ - # If the element is a boolean and the output schema specifies a `array[number]` type, - # convert the boolean value to an integer. - # - # This ensures compatibility with existing workflows that may use - # `True` and `False` as values for NUMBER type outputs. 
- self._convert_boolean_to_int(v) - for v in value - ] - elif output_config.type == SegmentType.ARRAY_STRING: - # check if array of string available - if not isinstance(result[output_name], list): - if result[output_name] is None: - transformed_result[output_name] = None - else: - raise OutputValidationError( - f"Output {prefix}{dot}{output_name} is not an array," - f" got {type(result.get(output_name))} instead." - ) - else: - if len(result[output_name]) > self._limits.max_string_array_length: - raise OutputValidationError( - f"The length of output variable `{prefix}{dot}{output_name}` must be" - f" less than {self._limits.max_string_array_length} elements." - ) - - transformed_result[output_name] = [ - self._check_string(value=value, variable=f"{prefix}{dot}{output_name}[{i}]") - for i, value in enumerate(result[output_name]) - ] - elif output_config.type == SegmentType.ARRAY_OBJECT: - # check if array of object available - if not isinstance(result[output_name], list): - if result[output_name] is None: - transformed_result[output_name] = None - else: - raise OutputValidationError( - f"Output {prefix}{dot}{output_name} is not an array," - f" got {type(result.get(output_name))} instead." - ) - else: - if len(result[output_name]) > self._limits.max_object_array_length: - raise OutputValidationError( - f"The length of output variable `{prefix}{dot}{output_name}` must be" - f" less than {self._limits.max_object_array_length} elements." - ) - - for i, value in enumerate(result[output_name]): - if not isinstance(value, dict): - if value is None: - pass - else: - raise OutputValidationError( - f"Output {prefix}{dot}{output_name}[{i}] is not an object," - f" got {type(value)} instead at index {i}." 
- ) - - transformed_result[output_name] = [ - None - if value is None - else self._transform_result( - result=value, - output_schema=output_config.children, - prefix=f"{prefix}{dot}{output_name}[{i}]", - depth=depth + 1, - ) - for i, value in enumerate(result[output_name]) - ] - elif output_config.type == SegmentType.ARRAY_BOOLEAN: - # check if array of object available - value = result[output_name] - if not isinstance(value, list): - if value is None: - transformed_result[output_name] = None - else: - raise OutputValidationError( - f"Output {prefix}{dot}{output_name} is not an array," - f" got {type(result.get(output_name))} instead." - ) - else: - for i, inner_value in enumerate(value): - if inner_value is not None and not isinstance(inner_value, bool): - raise OutputValidationError( - f"Output {prefix}{dot}{output_name}[{i}] is not a boolean," - f" got {type(inner_value)} instead." - ) - _ = self._check_boolean(value=inner_value, variable=f"{prefix}{dot}{output_name}[{i}]") - transformed_result[output_name] = value - - else: - raise OutputValidationError(f"Output type {output_config.type} is not supported.") - - parameters_validated[output_name] = True - - # check if all output parameters are validated - if len(parameters_validated) != len(result): - raise CodeNodeError("Not all output parameters are validated.") - - return transformed_result - - @classmethod - def _extract_variable_selector_to_variable_mapping( - cls, - *, - graph_config: Mapping[str, Any], - node_id: str, - node_data: CodeNodeData, - ) -> Mapping[str, Sequence[str]]: - _ = graph_config # Explicitly mark as unused - return { - node_id + "." 
+ variable_selector.variable: variable_selector.value_selector - for variable_selector in node_data.variables - } - - @property - def retry(self) -> bool: - return self.node_data.retry_config.retry_enabled - - @staticmethod - def _convert_boolean_to_int(value: bool | int | float | None) -> int | float | None: - """This function convert boolean to integers when the output schema specifies a NUMBER type. - - This ensures compatibility with existing workflows that may use - `True` and `False` as values for NUMBER type outputs. - """ - if value is None: - return None - if isinstance(value, bool): - return int(value) - return value diff --git a/api/dify_graph/nodes/code/entities.py b/api/dify_graph/nodes/code/entities.py deleted file mode 100644 index 55b4ee4862..0000000000 --- a/api/dify_graph/nodes/code/entities.py +++ /dev/null @@ -1,57 +0,0 @@ -from enum import StrEnum -from typing import Annotated, Literal - -from pydantic import AfterValidator, BaseModel - -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, NodeType -from dify_graph.nodes.base.entities import VariableSelector -from dify_graph.variables.types import SegmentType - - -class CodeLanguage(StrEnum): - PYTHON3 = "python3" - JINJA2 = "jinja2" - JAVASCRIPT = "javascript" - - -_ALLOWED_OUTPUT_FROM_CODE = frozenset( - [ - SegmentType.STRING, - SegmentType.NUMBER, - SegmentType.OBJECT, - SegmentType.BOOLEAN, - SegmentType.ARRAY_STRING, - SegmentType.ARRAY_NUMBER, - SegmentType.ARRAY_OBJECT, - SegmentType.ARRAY_BOOLEAN, - ] -) - - -def _validate_type(segment_type: SegmentType) -> SegmentType: - if segment_type not in _ALLOWED_OUTPUT_FROM_CODE: - raise ValueError(f"invalid type for code output, expected {_ALLOWED_OUTPUT_FROM_CODE}, actual {segment_type}") - return segment_type - - -class CodeNodeData(BaseNodeData): - """ - Code Node Data. 
- """ - - type: NodeType = BuiltinNodeTypes.CODE - - class Output(BaseModel): - type: Annotated[SegmentType, AfterValidator(_validate_type)] - children: dict[str, "CodeNodeData.Output"] | None = None - - class Dependency(BaseModel): - name: str - version: str - - variables: list[VariableSelector] - code_language: Literal[CodeLanguage.PYTHON3, CodeLanguage.JAVASCRIPT] - code: str - outputs: dict[str, Output] - dependencies: list[Dependency] | None = None diff --git a/api/dify_graph/nodes/code/exc.py b/api/dify_graph/nodes/code/exc.py deleted file mode 100644 index d6334fd554..0000000000 --- a/api/dify_graph/nodes/code/exc.py +++ /dev/null @@ -1,16 +0,0 @@ -class CodeNodeError(ValueError): - """Base class for code node errors.""" - - pass - - -class OutputValidationError(CodeNodeError): - """Raised when there is an output validation error.""" - - pass - - -class DepthLimitError(CodeNodeError): - """Raised when the depth limit is reached.""" - - pass diff --git a/api/dify_graph/nodes/code/limits.py b/api/dify_graph/nodes/code/limits.py deleted file mode 100644 index a6b9e9e68e..0000000000 --- a/api/dify_graph/nodes/code/limits.py +++ /dev/null @@ -1,13 +0,0 @@ -from dataclasses import dataclass - - -@dataclass(frozen=True) -class CodeNodeLimits: - max_string_length: int - max_number: int | float - min_number: int | float - max_precision: int - max_depth: int - max_number_array_length: int - max_string_array_length: int - max_object_array_length: int diff --git a/api/dify_graph/nodes/document_extractor/__init__.py b/api/dify_graph/nodes/document_extractor/__init__.py deleted file mode 100644 index 9922e3949d..0000000000 --- a/api/dify_graph/nodes/document_extractor/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .entities import DocumentExtractorNodeData, UnstructuredApiConfig -from .node import DocumentExtractorNode - -__all__ = ["DocumentExtractorNode", "DocumentExtractorNodeData", "UnstructuredApiConfig"] diff --git 
a/api/dify_graph/nodes/document_extractor/entities.py b/api/dify_graph/nodes/document_extractor/entities.py deleted file mode 100644 index 1110cc2710..0000000000 --- a/api/dify_graph/nodes/document_extractor/entities.py +++ /dev/null @@ -1,16 +0,0 @@ -from collections.abc import Sequence -from dataclasses import dataclass - -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, NodeType - - -class DocumentExtractorNodeData(BaseNodeData): - type: NodeType = BuiltinNodeTypes.DOCUMENT_EXTRACTOR - variable_selector: Sequence[str] - - -@dataclass(frozen=True) -class UnstructuredApiConfig: - api_url: str | None = None - api_key: str = "" diff --git a/api/dify_graph/nodes/document_extractor/exc.py b/api/dify_graph/nodes/document_extractor/exc.py deleted file mode 100644 index 5caf00ebc5..0000000000 --- a/api/dify_graph/nodes/document_extractor/exc.py +++ /dev/null @@ -1,14 +0,0 @@ -class DocumentExtractorError(ValueError): - """Base exception for errors related to the DocumentExtractorNode.""" - - -class FileDownloadError(DocumentExtractorError): - """Exception raised when there's an error downloading a file.""" - - -class UnsupportedFileTypeError(DocumentExtractorError): - """Exception raised when trying to extract text from an unsupported file type.""" - - -class TextExtractionError(DocumentExtractorError): - """Exception raised when there's an error during text extraction from a file.""" diff --git a/api/dify_graph/nodes/document_extractor/node.py b/api/dify_graph/nodes/document_extractor/node.py deleted file mode 100644 index 27196f1aca..0000000000 --- a/api/dify_graph/nodes/document_extractor/node.py +++ /dev/null @@ -1,782 +0,0 @@ -import csv -import io -import json -import logging -import os -import tempfile -import zipfile -from collections.abc import Mapping, Sequence -from typing import TYPE_CHECKING, Any - -import charset_normalizer -import docx -import pandas as pd -import pypandoc -import pypdfium2 -import 
webvtt -import yaml -from docx.document import Document -from docx.oxml.table import CT_Tbl -from docx.oxml.text.paragraph import CT_P -from docx.table import Table -from docx.text.paragraph import Paragraph - -from dify_graph.entities.graph_config import NodeConfigDict -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from dify_graph.file import File, FileTransferMethod, file_manager -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.protocols import HttpClientProtocol -from dify_graph.variables import ArrayFileSegment -from dify_graph.variables.segments import ArrayStringSegment, FileSegment - -from .entities import DocumentExtractorNodeData, UnstructuredApiConfig -from .exc import DocumentExtractorError, FileDownloadError, TextExtractionError, UnsupportedFileTypeError - -logger = logging.getLogger(__name__) - -if TYPE_CHECKING: - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState - - -class DocumentExtractorNode(Node[DocumentExtractorNodeData]): - """ - Extracts text content from various file types. - Supports plain text, PDF, and DOC/DOCX files. 
- """ - - node_type = BuiltinNodeTypes.DOCUMENT_EXTRACTOR - - @classmethod - def version(cls) -> str: - return "1" - - def __init__( - self, - id: str, - config: NodeConfigDict, - graph_init_params: "GraphInitParams", - graph_runtime_state: "GraphRuntimeState", - *, - unstructured_api_config: UnstructuredApiConfig | None = None, - http_client: HttpClientProtocol, - ) -> None: - super().__init__( - id=id, - config=config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - self._unstructured_api_config = unstructured_api_config or UnstructuredApiConfig() - self._http_client = http_client - - def _run(self): - variable_selector = self.node_data.variable_selector - variable = self.graph_runtime_state.variable_pool.get(variable_selector) - - if variable is None: - error_message = f"File variable not found for selector: {variable_selector}" - return NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, error=error_message) - if variable.value and not isinstance(variable, ArrayFileSegment | FileSegment): - error_message = f"Variable {variable_selector} is not an ArrayFileSegment" - return NodeRunResult(status=WorkflowNodeExecutionStatus.FAILED, error=error_message) - - value = variable.value - inputs = {"variable_selector": variable_selector} - if isinstance(value, list): - value = list(filter(lambda x: x, value)) - process_data = {"documents": value if isinstance(value, list) else [value]} - - if not value: - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=inputs, - process_data=process_data, - outputs={"text": ArrayStringSegment(value=[])}, - ) - - try: - if isinstance(value, list): - extracted_text_list = [ - _extract_text_from_file( - self._http_client, file, unstructured_api_config=self._unstructured_api_config - ) - for file in value - ] - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=inputs, - process_data=process_data, - outputs={"text": 
ArrayStringSegment(value=extracted_text_list)}, - ) - elif isinstance(value, File): - extracted_text = _extract_text_from_file( - self._http_client, value, unstructured_api_config=self._unstructured_api_config - ) - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=inputs, - process_data=process_data, - outputs={"text": extracted_text}, - ) - else: - raise DocumentExtractorError(f"Unsupported variable type: {type(value)}") - except DocumentExtractorError as e: - logger.warning(e, exc_info=True) - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error=str(e), - inputs=inputs, - process_data=process_data, - ) - - @classmethod - def _extract_variable_selector_to_variable_mapping( - cls, - *, - graph_config: Mapping[str, Any], - node_id: str, - node_data: DocumentExtractorNodeData, - ) -> Mapping[str, Sequence[str]]: - _ = graph_config # Explicitly mark as unused - return {node_id + ".files": node_data.variable_selector} - - -def _extract_text_by_mime_type( - *, - file_content: bytes, - mime_type: str, - unstructured_api_config: UnstructuredApiConfig, -) -> str: - """Extract text from a file based on its MIME type.""" - match mime_type: - case "text/plain" | "text/html" | "text/htm" | "text/markdown" | "text/xml": - return _extract_text_from_plain_text(file_content) - case "application/pdf": - return _extract_text_from_pdf(file_content) - case "application/msword": - return _extract_text_from_doc(file_content, unstructured_api_config=unstructured_api_config) - case "application/vnd.openxmlformats-officedocument.wordprocessingml.document": - return _extract_text_from_docx(file_content) - case "text/csv": - return _extract_text_from_csv(file_content) - case "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet" | "application/vnd.ms-excel": - return _extract_text_from_excel(file_content) - case "application/vnd.ms-powerpoint": - return _extract_text_from_ppt(file_content, 
unstructured_api_config=unstructured_api_config) - case "application/vnd.openxmlformats-officedocument.presentationml.presentation": - return _extract_text_from_pptx(file_content, unstructured_api_config=unstructured_api_config) - case "application/epub+zip": - return _extract_text_from_epub(file_content, unstructured_api_config=unstructured_api_config) - case "message/rfc822": - return _extract_text_from_eml(file_content) - case "application/vnd.ms-outlook": - return _extract_text_from_msg(file_content) - case "application/json": - return _extract_text_from_json(file_content) - case "application/x-yaml" | "text/yaml": - return _extract_text_from_yaml(file_content) - case "text/vtt": - return _extract_text_from_vtt(file_content) - case "text/properties": - return _extract_text_from_properties(file_content) - case _: - raise UnsupportedFileTypeError(f"Unsupported MIME type: {mime_type}") - - -def _extract_text_by_file_extension( - *, - file_content: bytes, - file_extension: str, - unstructured_api_config: UnstructuredApiConfig, -) -> str: - """Extract text from a file based on its file extension.""" - match file_extension: - case ( - ".txt" - | ".markdown" - | ".md" - | ".mdx" - | ".html" - | ".htm" - | ".xml" - | ".c" - | ".h" - | ".cpp" - | ".hpp" - | ".cc" - | ".cxx" - | ".c++" - | ".py" - | ".js" - | ".ts" - | ".jsx" - | ".tsx" - | ".java" - | ".php" - | ".rb" - | ".go" - | ".rs" - | ".swift" - | ".kt" - | ".scala" - | ".sh" - | ".bash" - | ".bat" - | ".ps1" - | ".sql" - | ".r" - | ".m" - | ".pl" - | ".lua" - | ".vim" - | ".asm" - | ".s" - | ".css" - | ".scss" - | ".less" - | ".sass" - | ".ini" - | ".cfg" - | ".conf" - | ".toml" - | ".env" - | ".log" - | ".vtt" - ): - return _extract_text_from_plain_text(file_content) - case ".json": - return _extract_text_from_json(file_content) - case ".yaml" | ".yml": - return _extract_text_from_yaml(file_content) - case ".pdf": - return _extract_text_from_pdf(file_content) - case ".doc": - return 
_extract_text_from_doc(file_content, unstructured_api_config=unstructured_api_config) - case ".docx": - return _extract_text_from_docx(file_content) - case ".csv": - return _extract_text_from_csv(file_content) - case ".xls" | ".xlsx": - return _extract_text_from_excel(file_content) - case ".ppt": - return _extract_text_from_ppt(file_content, unstructured_api_config=unstructured_api_config) - case ".pptx": - return _extract_text_from_pptx(file_content, unstructured_api_config=unstructured_api_config) - case ".epub": - return _extract_text_from_epub(file_content, unstructured_api_config=unstructured_api_config) - case ".eml": - return _extract_text_from_eml(file_content) - case ".msg": - return _extract_text_from_msg(file_content) - case ".properties": - return _extract_text_from_properties(file_content) - case _: - raise UnsupportedFileTypeError(f"Unsupported Extension Type: {file_extension}") - - -def _extract_text_from_plain_text(file_content: bytes) -> str: - try: - # Detect encoding using charset_normalizer - result = charset_normalizer.from_bytes(file_content, cp_isolation=["utf_8", "latin_1", "cp1252"]).best() - if result: - encoding = result.encoding - else: - encoding = "utf-8" - - # Fallback to utf-8 if detection fails - if not encoding: - encoding = "utf-8" - - return file_content.decode(encoding, errors="ignore") - except (UnicodeDecodeError, LookupError) as e: - # If decoding fails, try with utf-8 as last resort - try: - return file_content.decode("utf-8", errors="ignore") - except UnicodeDecodeError: - raise TextExtractionError(f"Failed to decode plain text file: {e}") from e - - -def _extract_text_from_json(file_content: bytes) -> str: - try: - # Detect encoding using charset_normalizer - result = charset_normalizer.from_bytes(file_content).best() - if result: - encoding = result.encoding - else: - encoding = "utf-8" - - # Fallback to utf-8 if detection fails - if not encoding: - encoding = "utf-8" - - json_data = 
json.loads(file_content.decode(encoding, errors="ignore")) - return json.dumps(json_data, indent=2, ensure_ascii=False) - except (UnicodeDecodeError, LookupError, json.JSONDecodeError) as e: - # If decoding fails, try with utf-8 as last resort - try: - json_data = json.loads(file_content.decode("utf-8", errors="ignore")) - return json.dumps(json_data, indent=2, ensure_ascii=False) - except (UnicodeDecodeError, json.JSONDecodeError): - raise TextExtractionError(f"Failed to decode or parse JSON file: {e}") from e - - -def _extract_text_from_yaml(file_content: bytes) -> str: - """Extract the content from yaml file""" - try: - # Detect encoding using charset_normalizer - result = charset_normalizer.from_bytes(file_content).best() - if result: - encoding = result.encoding - else: - encoding = "utf-8" - - # Fallback to utf-8 if detection fails - if not encoding: - encoding = "utf-8" - - yaml_data = yaml.safe_load_all(file_content.decode(encoding, errors="ignore")) - return yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False) - except (UnicodeDecodeError, LookupError, yaml.YAMLError) as e: - # If decoding fails, try with utf-8 as last resort - try: - yaml_data = yaml.safe_load_all(file_content.decode("utf-8", errors="ignore")) - return yaml.dump_all(yaml_data, allow_unicode=True, sort_keys=False) - except (UnicodeDecodeError, yaml.YAMLError): - raise TextExtractionError(f"Failed to decode or parse YAML file: {e}") from e - - -def _extract_text_from_pdf(file_content: bytes) -> str: - try: - pdf_file = io.BytesIO(file_content) - pdf_document = pypdfium2.PdfDocument(pdf_file, autoclose=True) - text = "" - for page in pdf_document: - text_page = page.get_textpage() - text += text_page.get_text_range() - text_page.close() - page.close() - return text - except Exception as e: - raise TextExtractionError(f"Failed to extract text from PDF: {str(e)}") from e - - -def _extract_text_from_doc(file_content: bytes, *, unstructured_api_config: UnstructuredApiConfig) -> str: - 
""" - Extract text from a DOC file. - """ - from unstructured.partition.api import partition_via_api - - if not unstructured_api_config.api_url: - raise TextExtractionError("Unstructured API URL is not configured for DOC file processing.") - api_key = unstructured_api_config.api_key or "" - - try: - with tempfile.NamedTemporaryFile(suffix=".doc", delete=False) as temp_file: - temp_file.write(file_content) - temp_file.flush() - with open(temp_file.name, "rb") as file: - elements = partition_via_api( - file=file, - metadata_filename=temp_file.name, - api_url=unstructured_api_config.api_url, - api_key=api_key, - ) - os.unlink(temp_file.name) - return "\n".join([getattr(element, "text", "") for element in elements]) - except Exception as e: - raise TextExtractionError(f"Failed to extract text from DOC: {str(e)}") from e - - -def parser_docx_part(block, doc: Document, content_items, i): - if isinstance(block, CT_P): - content_items.append((i, "paragraph", Paragraph(block, doc))) - elif isinstance(block, CT_Tbl): - content_items.append((i, "table", Table(block, doc))) - - -def _normalize_docx_zip(file_content: bytes) -> bytes: - """ - Some DOCX files (e.g. exported by Evernote on Windows) are malformed: - ZIP entry names use backslash (\\) as path separator instead of the forward - slash (/) required by both the ZIP spec and OOXML. On Linux/Mac the entry - "word\\document.xml" is never found when python-docx looks for - "word/document.xml", which triggers a KeyError about a missing relationship. - - This function rewrites the ZIP in-memory, normalizing all entry names to - use forward slashes without touching any actual document content. 
- """ - try: - with zipfile.ZipFile(io.BytesIO(file_content), "r") as zin: - out_buf = io.BytesIO() - with zipfile.ZipFile(out_buf, "w", compression=zipfile.ZIP_DEFLATED) as zout: - for item in zin.infolist(): - data = zin.read(item.filename) - # Normalize backslash path separators to forward slash - item.filename = item.filename.replace("\\", "/") - zout.writestr(item, data) - return out_buf.getvalue() - except zipfile.BadZipFile: - # Not a valid zip — return as-is and let python-docx report the real error - return file_content - - -def _extract_text_from_docx(file_content: bytes) -> str: - """ - Extract text from a DOCX file. - For now support only paragraph and table add more if needed - """ - try: - doc_file = io.BytesIO(file_content) - try: - doc = docx.Document(doc_file) - except Exception as e: - logger.warning("Failed to parse DOCX, attempting to normalize ZIP entry paths: %s", e) - # Some DOCX files exported by tools like Evernote on Windows use - # backslash path separators in ZIP entries and/or single-quoted XML - # attributes, both of which break python-docx on Linux. Normalize and retry. 
- file_content = _normalize_docx_zip(file_content) - doc = docx.Document(io.BytesIO(file_content)) - text = [] - - # Keep track of paragraph and table positions - content_items: list[tuple[int, str, Table | Paragraph]] = [] - - it = iter(doc.element.body) - part = next(it, None) - i = 0 - while part is not None: - parser_docx_part(part, doc, content_items, i) - i = i + 1 - part = next(it, None) - - # Process sorted content - for _, item_type, item in content_items: - if item_type == "paragraph": - if isinstance(item, Table): - continue - text.append(item.text) - elif item_type == "table": - # Process tables - if not isinstance(item, Table): - continue - try: - # Check if any cell in the table has text - has_content = False - for row in item.rows: - if any(cell.text.strip() for cell in row.cells): - has_content = True - break - - if has_content: - cell_texts = [cell.text.replace("\n", "
") for cell in item.rows[0].cells] - markdown_table = f"| {' | '.join(cell_texts)} |\n" - markdown_table += f"| {' | '.join(['---'] * len(item.rows[0].cells))} |\n" - - for row in item.rows[1:]: - # Replace newlines with
in each cell - row_cells = [cell.text.replace("\n", "
") for cell in row.cells] - markdown_table += "| " + " | ".join(row_cells) + " |\n" - - text.append(markdown_table) - except Exception as e: - logger.warning("Failed to extract table from DOC: %s", e) - continue - - return "\n".join(text) - - except Exception as e: - raise TextExtractionError(f"Failed to extract text from DOCX: {str(e)}") from e - - -def _download_file_content(http_client: HttpClientProtocol, file: File) -> bytes: - """Download the content of a file based on its transfer method.""" - try: - if file.transfer_method == FileTransferMethod.REMOTE_URL: - if file.remote_url is None: - raise FileDownloadError("Missing URL for remote file") - response = http_client.get(file.remote_url) - response.raise_for_status() - return response.content - else: - return file_manager.download(file) - except Exception as e: - raise FileDownloadError(f"Error downloading file: {str(e)}") from e - - -def _extract_text_from_file( - http_client: HttpClientProtocol, file: File, *, unstructured_api_config: UnstructuredApiConfig -) -> str: - file_content = _download_file_content(http_client, file) - if file.extension: - extracted_text = _extract_text_by_file_extension( - file_content=file_content, - file_extension=file.extension, - unstructured_api_config=unstructured_api_config, - ) - elif file.mime_type: - extracted_text = _extract_text_by_mime_type( - file_content=file_content, - mime_type=file.mime_type, - unstructured_api_config=unstructured_api_config, - ) - else: - raise UnsupportedFileTypeError("Unable to determine file type: MIME type or file extension is missing") - return extracted_text - - -def _extract_text_from_csv(file_content: bytes) -> str: - try: - # Detect encoding using charset_normalizer - result = charset_normalizer.from_bytes(file_content).best() - if result: - encoding = result.encoding - else: - encoding = "utf-8" - - # Fallback to utf-8 if detection fails - if not encoding: - encoding = "utf-8" - - try: - csv_file = 
io.StringIO(file_content.decode(encoding, errors="ignore")) - except (UnicodeDecodeError, LookupError): - # If decoding fails, try with utf-8 as last resort - csv_file = io.StringIO(file_content.decode("utf-8", errors="ignore")) - - csv_reader = csv.reader(csv_file) - rows = list(csv_reader) - - if not rows: - return "" - - # Combine multi-line text in the header row - header_row = [cell.replace("\n", " ").replace("\r", "") for cell in rows[0]] - - # Create Markdown table - markdown_table = "| " + " | ".join(header_row) + " |\n" - markdown_table += "| " + " | ".join(["-" * len(col) for col in rows[0]]) + " |\n" - - # Process each data row and combine multi-line text in each cell - for row in rows[1:]: - processed_row = [cell.replace("\n", " ").replace("\r", "") for cell in row] - markdown_table += "| " + " | ".join(processed_row) + " |\n" - - return markdown_table - except Exception as e: - raise TextExtractionError(f"Failed to extract text from CSV: {str(e)}") from e - - -def _extract_text_from_excel(file_content: bytes) -> str: - """Extract text from an Excel file using pandas.""" - - def _construct_markdown_table(df: pd.DataFrame) -> str: - """Manually construct a Markdown table from a DataFrame.""" - # Construct the header row - header_row = "| " + " | ".join(df.columns) + " |" - - # Construct the separator row - separator_row = "| " + " | ".join(["-" * len(col) for col in df.columns]) + " |" - - # Construct the data rows - data_rows = [] - for _, row in df.iterrows(): - data_row = "| " + " | ".join(map(str, row)) + " |" - data_rows.append(data_row) - - # Combine all rows into a single string - markdown_table = "\n".join([header_row, separator_row] + data_rows) - return markdown_table - - try: - excel_file = pd.ExcelFile(io.BytesIO(file_content)) - markdown_table = "" - for sheet_name in excel_file.sheet_names: - try: - df = excel_file.parse(sheet_name=sheet_name) - df.dropna(how="all", inplace=True) - - # Combine multi-line text in each cell into a single line 
- df = df.map(lambda x: " ".join(str(x).splitlines()) if isinstance(x, str) else x) - - # Combine multi-line text in column names into a single line - df.columns = pd.Index([" ".join(str(col).splitlines()) for col in df.columns]) - - # Manually construct the Markdown table - markdown_table += _construct_markdown_table(df) + "\n\n" - except Exception: - continue - return markdown_table - except Exception as e: - raise TextExtractionError(f"Failed to extract text from Excel file: {str(e)}") from e - - -def _extract_text_from_ppt(file_content: bytes, *, unstructured_api_config: UnstructuredApiConfig) -> str: - from unstructured.partition.api import partition_via_api - from unstructured.partition.ppt import partition_ppt - - api_key = unstructured_api_config.api_key or "" - - try: - if unstructured_api_config.api_url: - with tempfile.NamedTemporaryFile(suffix=".ppt", delete=False) as temp_file: - temp_file.write(file_content) - temp_file.flush() - with open(temp_file.name, "rb") as file: - elements = partition_via_api( - file=file, - metadata_filename=temp_file.name, - api_url=unstructured_api_config.api_url, - api_key=api_key, - ) - os.unlink(temp_file.name) - else: - with io.BytesIO(file_content) as file: - elements = partition_ppt(file=file) - return "\n".join([getattr(element, "text", "") for element in elements]) - - except Exception as e: - raise TextExtractionError(f"Failed to extract text from PPTX: {str(e)}") from e - - -def _extract_text_from_pptx(file_content: bytes, *, unstructured_api_config: UnstructuredApiConfig) -> str: - from unstructured.partition.api import partition_via_api - from unstructured.partition.pptx import partition_pptx - - api_key = unstructured_api_config.api_key or "" - - try: - if unstructured_api_config.api_url: - with tempfile.NamedTemporaryFile(suffix=".pptx", delete=False) as temp_file: - temp_file.write(file_content) - temp_file.flush() - with open(temp_file.name, "rb") as file: - elements = partition_via_api( - file=file, - 
metadata_filename=temp_file.name, - api_url=unstructured_api_config.api_url, - api_key=api_key, - ) - os.unlink(temp_file.name) - else: - with io.BytesIO(file_content) as file: - elements = partition_pptx(file=file) - return "\n".join([getattr(element, "text", "") for element in elements]) - except Exception as e: - raise TextExtractionError(f"Failed to extract text from PPTX: {str(e)}") from e - - -def _extract_text_from_epub(file_content: bytes, *, unstructured_api_config: UnstructuredApiConfig) -> str: - from unstructured.partition.api import partition_via_api - from unstructured.partition.epub import partition_epub - - api_key = unstructured_api_config.api_key or "" - - try: - if unstructured_api_config.api_url: - with tempfile.NamedTemporaryFile(suffix=".epub", delete=False) as temp_file: - temp_file.write(file_content) - temp_file.flush() - with open(temp_file.name, "rb") as file: - elements = partition_via_api( - file=file, - metadata_filename=temp_file.name, - api_url=unstructured_api_config.api_url, - api_key=api_key, - ) - os.unlink(temp_file.name) - else: - pypandoc.download_pandoc() - with io.BytesIO(file_content) as file: - elements = partition_epub(file=file) - return "\n".join([str(element) for element in elements]) - except Exception as e: - raise TextExtractionError(f"Failed to extract text from EPUB: {str(e)}") from e - - -def _extract_text_from_eml(file_content: bytes) -> str: - from unstructured.partition.email import partition_email - - try: - with io.BytesIO(file_content) as file: - elements = partition_email(file=file) - return "\n".join([str(element) for element in elements]) - except Exception as e: - raise TextExtractionError(f"Failed to extract text from EML: {str(e)}") from e - - -def _extract_text_from_msg(file_content: bytes) -> str: - from unstructured.partition.msg import partition_msg - - try: - with io.BytesIO(file_content) as file: - elements = partition_msg(file=file) - return "\n".join([str(element) for element in elements]) - 
except Exception as e: - raise TextExtractionError(f"Failed to extract text from MSG: {str(e)}") from e - - -def _extract_text_from_vtt(vtt_bytes: bytes) -> str: - text = _extract_text_from_plain_text(vtt_bytes) - - # remove bom - text = text.lstrip("\ufeff") - - raw_results = [] - for caption in webvtt.from_string(text): - raw_results.append((caption.voice, caption.text)) - - # Merge consecutive utterances by the same speaker - merged_results = [] - if raw_results: - current_speaker, current_text = raw_results[0] - - for i in range(1, len(raw_results)): - spk, txt = raw_results[i] - if spk is None: - merged_results.append((None, current_text)) - continue - - if spk == current_speaker: - # If it is the same speaker, merge the utterances (joined by space) - current_text += " " + txt - else: - # If the speaker changes, register the utterance so far and move on - merged_results.append((current_speaker, current_text)) - current_speaker, current_text = spk, txt - - # Add the last element - merged_results.append((current_speaker, current_text)) - else: - merged_results = raw_results - - # Return the result in the specified format: Speaker "text" style - formatted = [f'{spk or ""} "{txt}"' for spk, txt in merged_results] - return "\n".join(formatted) - - -def _extract_text_from_properties(file_content: bytes) -> str: - try: - text = _extract_text_from_plain_text(file_content) - lines = text.splitlines() - result = [] - for line in lines: - line = line.strip() - # Preserve comments and empty lines - if not line or line.startswith("#") or line.startswith("!"): - result.append(line) - continue - - if "=" in line: - key, value = line.split("=", 1) - elif ":" in line: - key, value = line.split(":", 1) - else: - key, value = line, "" - - result.append(f"{key.strip()}: {value.strip()}") - - return "\n".join(result) - except Exception as e: - raise TextExtractionError(f"Failed to extract text from properties file: {str(e)}") from e diff --git 
a/api/dify_graph/nodes/end/__init__.py b/api/dify_graph/nodes/end/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/api/dify_graph/nodes/end/end_node.py b/api/dify_graph/nodes/end/end_node.py deleted file mode 100644 index 1f5cfab22b..0000000000 --- a/api/dify_graph/nodes/end/end_node.py +++ /dev/null @@ -1,47 +0,0 @@ -from dify_graph.enums import BuiltinNodeTypes, NodeExecutionType, WorkflowNodeExecutionStatus -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.base.template import Template -from dify_graph.nodes.end.entities import EndNodeData - - -class EndNode(Node[EndNodeData]): - node_type = BuiltinNodeTypes.END - execution_type = NodeExecutionType.RESPONSE - - @classmethod - def version(cls) -> str: - return "1" - - def _run(self) -> NodeRunResult: - """ - Run node - collect all outputs at once. - - This method runs after streaming is complete (if streaming was enabled). - It collects all output variables and returns them. - """ - output_variables = self.node_data.outputs - - outputs = {} - for variable_selector in output_variables: - variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector) - value = variable.to_object() if variable is not None else None - outputs[variable_selector.variable] = value - - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=outputs, - outputs=outputs, - ) - - def get_streaming_template(self) -> Template: - """ - Get the template for streaming. 
- - Returns: - Template instance for this End node - """ - outputs_config = [ - {"variable": output.variable, "value_selector": output.value_selector} for output in self.node_data.outputs - ] - return Template.from_end_outputs(outputs_config) diff --git a/api/dify_graph/nodes/end/entities.py b/api/dify_graph/nodes/end/entities.py deleted file mode 100644 index be7f0c8de8..0000000000 --- a/api/dify_graph/nodes/end/entities.py +++ /dev/null @@ -1,27 +0,0 @@ -from pydantic import BaseModel, Field - -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, NodeType -from dify_graph.nodes.base.entities import OutputVariableEntity - - -class EndNodeData(BaseNodeData): - """ - END Node Data. - """ - - type: NodeType = BuiltinNodeTypes.END - outputs: list[OutputVariableEntity] - - -class EndStreamParam(BaseModel): - """ - EndStreamParam entity - """ - - end_dependencies: dict[str, list[str]] = Field( - ..., description="end dependencies (end node id -> dependent node ids)" - ) - end_stream_variable_selector_mapping: dict[str, list[list[str]]] = Field( - ..., description="end stream variable selector mapping (end node id -> stream variable selectors)" - ) diff --git a/api/dify_graph/nodes/http_request/__init__.py b/api/dify_graph/nodes/http_request/__init__.py deleted file mode 100644 index b29099db23..0000000000 --- a/api/dify_graph/nodes/http_request/__init__.py +++ /dev/null @@ -1,22 +0,0 @@ -from .config import build_http_request_config, resolve_http_request_config -from .entities import ( - HTTP_REQUEST_CONFIG_FILTER_KEY, - BodyData, - HttpRequestNodeAuthorization, - HttpRequestNodeBody, - HttpRequestNodeConfig, - HttpRequestNodeData, -) -from .node import HttpRequestNode - -__all__ = [ - "HTTP_REQUEST_CONFIG_FILTER_KEY", - "BodyData", - "HttpRequestNode", - "HttpRequestNodeAuthorization", - "HttpRequestNodeBody", - "HttpRequestNodeConfig", - "HttpRequestNodeData", - "build_http_request_config", - 
"resolve_http_request_config", -] diff --git a/api/dify_graph/nodes/http_request/config.py b/api/dify_graph/nodes/http_request/config.py deleted file mode 100644 index 53bf6c7ae4..0000000000 --- a/api/dify_graph/nodes/http_request/config.py +++ /dev/null @@ -1,33 +0,0 @@ -from collections.abc import Mapping - -from .entities import HTTP_REQUEST_CONFIG_FILTER_KEY, HttpRequestNodeConfig - - -def build_http_request_config( - *, - max_connect_timeout: int = 10, - max_read_timeout: int = 600, - max_write_timeout: int = 600, - max_binary_size: int = 10 * 1024 * 1024, - max_text_size: int = 1 * 1024 * 1024, - ssl_verify: bool = True, - ssrf_default_max_retries: int = 3, -) -> HttpRequestNodeConfig: - return HttpRequestNodeConfig( - max_connect_timeout=max_connect_timeout, - max_read_timeout=max_read_timeout, - max_write_timeout=max_write_timeout, - max_binary_size=max_binary_size, - max_text_size=max_text_size, - ssl_verify=ssl_verify, - ssrf_default_max_retries=ssrf_default_max_retries, - ) - - -def resolve_http_request_config(filters: Mapping[str, object] | None) -> HttpRequestNodeConfig: - if not filters: - raise ValueError("http_request_config is required to build HTTP request default config") - config = filters.get(HTTP_REQUEST_CONFIG_FILTER_KEY) - if not isinstance(config, HttpRequestNodeConfig): - raise ValueError("http_request_config must be an HttpRequestNodeConfig instance") - return config diff --git a/api/dify_graph/nodes/http_request/entities.py b/api/dify_graph/nodes/http_request/entities.py deleted file mode 100644 index f594d58ae6..0000000000 --- a/api/dify_graph/nodes/http_request/entities.py +++ /dev/null @@ -1,241 +0,0 @@ -import mimetypes -from collections.abc import Sequence -from dataclasses import dataclass -from email.message import Message -from typing import Any, Literal - -import charset_normalizer -import httpx -from pydantic import BaseModel, Field, ValidationInfo, field_validator - -from dify_graph.entities.base_node_data import BaseNodeData 
-from dify_graph.enums import BuiltinNodeTypes, NodeType - -HTTP_REQUEST_CONFIG_FILTER_KEY = "http_request_config" - - -class HttpRequestNodeAuthorizationConfig(BaseModel): - type: Literal["basic", "bearer", "custom"] - api_key: str - header: str = "" - - -class HttpRequestNodeAuthorization(BaseModel): - type: Literal["no-auth", "api-key"] - config: HttpRequestNodeAuthorizationConfig | None = None - - @field_validator("config", mode="before") - @classmethod - def check_config(cls, v: HttpRequestNodeAuthorizationConfig, values: ValidationInfo): - """ - Check config, if type is no-auth, config should be None, otherwise it should be a dict. - """ - if values.data["type"] == "no-auth": - return None - else: - if not v or not isinstance(v, dict): - raise ValueError("config should be a dict") - - return v - - -class BodyData(BaseModel): - key: str = "" - type: Literal["file", "text"] - value: str = "" - file: Sequence[str] = Field(default_factory=list) - - -class HttpRequestNodeBody(BaseModel): - type: Literal["none", "form-data", "x-www-form-urlencoded", "raw-text", "json", "binary"] - data: Sequence[BodyData] = Field(default_factory=list) - - @field_validator("data", mode="before") - @classmethod - def check_data(cls, v: Any): - """For compatibility, if body is not set, return empty list.""" - if not v: - return [] - if isinstance(v, str): - return [BodyData(key="", type="text", value=v)] - return v - - -class HttpRequestNodeTimeout(BaseModel): - connect: int | None = None - read: int | None = None - write: int | None = None - - -@dataclass(frozen=True, slots=True) -class HttpRequestNodeConfig: - max_connect_timeout: int - max_read_timeout: int - max_write_timeout: int - max_binary_size: int - max_text_size: int - ssl_verify: bool - ssrf_default_max_retries: int - - def default_timeout(self) -> "HttpRequestNodeTimeout": - return HttpRequestNodeTimeout( - connect=self.max_connect_timeout, - read=self.max_read_timeout, - write=self.max_write_timeout, - ) - - -class 
HttpRequestNodeData(BaseNodeData): - """ - Code Node Data. - """ - - type: NodeType = BuiltinNodeTypes.HTTP_REQUEST - method: Literal[ - "get", - "post", - "put", - "patch", - "delete", - "head", - "options", - "GET", - "POST", - "PUT", - "PATCH", - "DELETE", - "HEAD", - "OPTIONS", - ] - url: str - authorization: HttpRequestNodeAuthorization - headers: str - params: str - body: HttpRequestNodeBody | None = None - timeout: HttpRequestNodeTimeout | None = None - ssl_verify: bool | None = None - - -class Response: - headers: dict[str, str] - response: httpx.Response - _cached_text: str | None - - def __init__(self, response: httpx.Response): - self.response = response - self.headers = dict(response.headers) - self._cached_text = None - - @property - def is_file(self): - """ - Determine if the response contains a file by checking: - 1. Content-Disposition header (RFC 6266) - 2. Content characteristics - 3. MIME type analysis - """ - content_type = self.content_type.split(";")[0].strip().lower() - parsed_content_disposition = self.parsed_content_disposition - - # Check if it's explicitly marked as an attachment - if parsed_content_disposition: - disp_type = parsed_content_disposition.get_content_disposition() # Returns 'attachment', 'inline', or None - filename = parsed_content_disposition.get_filename() # Returns filename if present, None otherwise - if disp_type == "attachment" or filename is not None: - return True - - # For 'text/' types, only 'csv' should be downloaded as file - if content_type.startswith("text/") and "csv" not in content_type: - return False - - # For application types, try to detect if it's a text-based format - if content_type.startswith("application/"): - # Common text-based application types - if any( - text_type in content_type - for text_type in ("json", "xml", "javascript", "x-www-form-urlencoded", "yaml", "graphql") - ): - return False - - # Try to detect if content is text-based by sampling first few bytes - try: - # Sample first 1024 
bytes for text detection - content_sample = self.response.content[:1024] - content_sample.decode("utf-8") - # If we can decode as UTF-8 and find common text patterns, likely not a file - text_markers = (b"{", b"[", b"<", b"function", b"var ", b"const ", b"let ") - if any(marker in content_sample for marker in text_markers): - return False - except UnicodeDecodeError: - # If we can't decode as UTF-8, likely a binary file - return True - - # For other types, use MIME type analysis - main_type, _ = mimetypes.guess_type("dummy" + (mimetypes.guess_extension(content_type) or "")) - if main_type: - return main_type.split("/")[0] in ("application", "image", "audio", "video") - - # For unknown types, check if it's a media type - return any(media_type in content_type for media_type in ("image/", "audio/", "video/")) - - @property - def content_type(self) -> str: - return self.headers.get("content-type", "") - - @property - def text(self) -> str: - """ - Get response text with robust encoding detection. - - Uses charset_normalizer for better encoding detection than httpx's default, - which helps handle Chinese and other non-ASCII characters properly. 
- """ - # Check cache first - if hasattr(self, "_cached_text") and self._cached_text is not None: - return self._cached_text - - # Try charset_normalizer for robust encoding detection first - detected_encoding = charset_normalizer.from_bytes(self.response.content).best() - if detected_encoding and detected_encoding.encoding: - try: - text = self.response.content.decode(detected_encoding.encoding) - self._cached_text = text - return text - except (UnicodeDecodeError, TypeError, LookupError): - # Fallback to httpx's encoding detection if charset_normalizer fails - pass - - # Fallback to httpx's built-in encoding detection - text = self.response.text - self._cached_text = text - return text - - @property - def content(self) -> bytes: - return self.response.content - - @property - def status_code(self) -> int: - return self.response.status_code - - @property - def size(self) -> int: - return len(self.content) - - @property - def readable_size(self) -> str: - if self.size < 1024: - return f"{self.size} bytes" - elif self.size < 1024 * 1024: - return f"{(self.size / 1024):.2f} KB" - else: - return f"{(self.size / 1024 / 1024):.2f} MB" - - @property - def parsed_content_disposition(self) -> Message | None: - content_disposition = self.headers.get("content-disposition", "") - if content_disposition: - msg = Message() - msg["content-disposition"] = content_disposition - return msg - return None diff --git a/api/dify_graph/nodes/http_request/exc.py b/api/dify_graph/nodes/http_request/exc.py deleted file mode 100644 index 46613c9e86..0000000000 --- a/api/dify_graph/nodes/http_request/exc.py +++ /dev/null @@ -1,26 +0,0 @@ -class HttpRequestNodeError(ValueError): - """Custom error for HTTP request node.""" - - -class AuthorizationConfigError(HttpRequestNodeError): - """Raised when authorization config is missing or invalid.""" - - -class FileFetchError(HttpRequestNodeError): - """Raised when a file cannot be fetched.""" - - -class InvalidHttpMethodError(HttpRequestNodeError): - 
"""Raised when an invalid HTTP method is used.""" - - -class ResponseSizeError(HttpRequestNodeError): - """Raised when the response size exceeds the allowed threshold.""" - - -class RequestBodyError(HttpRequestNodeError): - """Raised when the request body is invalid.""" - - -class InvalidURLError(HttpRequestNodeError): - """Raised when the URL is invalid.""" diff --git a/api/dify_graph/nodes/http_request/executor.py b/api/dify_graph/nodes/http_request/executor.py deleted file mode 100644 index 892b0fc688..0000000000 --- a/api/dify_graph/nodes/http_request/executor.py +++ /dev/null @@ -1,488 +0,0 @@ -import base64 -import json -import secrets -import string -from collections.abc import Callable, Mapping -from copy import deepcopy -from typing import Any, Literal -from urllib.parse import urlencode, urlparse - -import httpx -from json_repair import repair_json - -from dify_graph.file.enums import FileTransferMethod -from dify_graph.runtime import VariablePool -from dify_graph.variables.segments import ArrayFileSegment, FileSegment - -from ..protocols import FileManagerProtocol, HttpClientProtocol -from .entities import ( - HttpRequestNodeAuthorization, - HttpRequestNodeConfig, - HttpRequestNodeData, - HttpRequestNodeTimeout, - Response, -) -from .exc import ( - AuthorizationConfigError, - FileFetchError, - HttpRequestNodeError, - InvalidHttpMethodError, - InvalidURLError, - RequestBodyError, - ResponseSizeError, -) - -BODY_TYPE_TO_CONTENT_TYPE = { - "json": "application/json", - "x-www-form-urlencoded": "application/x-www-form-urlencoded", - "form-data": "multipart/form-data", - "raw-text": "text/plain", -} - - -class Executor: - method: Literal[ - "get", - "head", - "post", - "put", - "delete", - "patch", - "options", - "GET", - "POST", - "PUT", - "PATCH", - "DELETE", - "HEAD", - "OPTIONS", - ] - url: str - params: list[tuple[str, str]] | None - content: str | bytes | None - data: Mapping[str, Any] | None - files: list[tuple[str, tuple[str | None, bytes, str]]] | 
None - json: Any - headers: dict[str, str] - auth: HttpRequestNodeAuthorization - timeout: HttpRequestNodeTimeout - max_retries: int - - boundary: str - - def __init__( - self, - *, - node_data: HttpRequestNodeData, - timeout: HttpRequestNodeTimeout, - variable_pool: VariablePool, - http_request_config: HttpRequestNodeConfig, - max_retries: int | None = None, - ssl_verify: bool | None = None, - http_client: HttpClientProtocol, - file_manager: FileManagerProtocol, - ): - self._http_request_config = http_request_config - # If authorization API key is present, convert the API key using the variable pool - if node_data.authorization.type == "api-key": - if node_data.authorization.config is None: - raise AuthorizationConfigError("authorization config is required") - node_data.authorization.config.api_key = variable_pool.convert_template( - node_data.authorization.config.api_key - ).text - # Validate that API key is not empty after template conversion - if not node_data.authorization.config.api_key or not node_data.authorization.config.api_key.strip(): - raise AuthorizationConfigError( - "API key is required for authorization but was empty. Please provide a valid API key." 
- ) - - self.url = node_data.url - self.method = node_data.method - self.auth = node_data.authorization - self.timeout = timeout - self.ssl_verify = ssl_verify if ssl_verify is not None else node_data.ssl_verify - if self.ssl_verify is None: - self.ssl_verify = self._http_request_config.ssl_verify - if not isinstance(self.ssl_verify, bool): - raise ValueError("ssl_verify must be a boolean") - self.params = None - self.headers = {} - self.content = None - self.files = None - self.data = None - self.json = None - self.max_retries = ( - max_retries if max_retries is not None else self._http_request_config.ssrf_default_max_retries - ) - self._http_client = http_client - self._file_manager = file_manager - - # init template - self.variable_pool = variable_pool - self.node_data = node_data - self._initialize() - - def _initialize(self): - self._init_url() - self._init_params() - self._init_headers() - self._init_body() - - def _init_url(self): - self.url = self.variable_pool.convert_template(self.node_data.url).text - - # check if url is a valid URL - if not self.url: - raise InvalidURLError("url is required") - if not self.url.startswith(("http://", "https://")): - raise InvalidURLError("url should start with http:// or https://") - - def _init_params(self): - """ - Almost same as _init_headers(), difference: - 1. response a list tuple to support same key, like 'aa=1&aa=2' - 2. param value may have '\n', we need to splitlines then extract the variable value. - """ - result = [] - for line in self.node_data.params.splitlines(): - if not (line := line.strip()): - continue - - key, *value = line.split(":", 1) - if not (key := key.strip()): - continue - - value_str = value[0].strip() if value else "" - result.append( - (self.variable_pool.convert_template(key).text, self.variable_pool.convert_template(value_str).text) - ) - - if result: - self.params = result - - def _init_headers(self): - """ - Convert the header string of frontend to a dictionary. 
- - Each line in the header string represents a key-value pair. - Keys and values are separated by ':'. - Empty values are allowed. - - Examples: - 'aa:bb\n cc:dd' -> {'aa': 'bb', 'cc': 'dd'} - 'aa:\n cc:dd\n' -> {'aa': '', 'cc': 'dd'} - 'aa\n cc : dd' -> {'aa': '', 'cc': 'dd'} - - """ - headers = self.variable_pool.convert_template(self.node_data.headers).text - self.headers = { - key.strip(): (value[0].strip() if value else "") - for line in headers.splitlines() - if line.strip() - for key, *value in [line.split(":", 1)] - } - - def _init_body(self): - body = self.node_data.body - if body is not None: - data = body.data - match body.type: - case "none": - self.content = "" - case "raw-text": - if len(data) != 1: - raise RequestBodyError("raw-text body type should have exactly one item") - self.content = self.variable_pool.convert_template(data[0].value).text - case "json": - if len(data) != 1: - raise RequestBodyError("json body type should have exactly one item") - json_string = self.variable_pool.convert_template(data[0].value).text - try: - repaired = repair_json(json_string) - json_object = json.loads(repaired, strict=False) - except json.JSONDecodeError as e: - raise RequestBodyError(f"Failed to parse JSON: {json_string}") from e - self.json = json_object - # self.json = self._parse_object_contains_variables(json_object) - case "binary": - if len(data) != 1: - raise RequestBodyError("binary body type should have exactly one item") - file_selector = data[0].file - file_variable = self.variable_pool.get_file(file_selector) - if file_variable is None: - raise FileFetchError(f"cannot fetch file with selector {file_selector}") - file = file_variable.value - self.content = self._file_manager.download(file) - case "x-www-form-urlencoded": - form_data = { - self.variable_pool.convert_template(item.key).text: self.variable_pool.convert_template( - item.value - ).text - for item in data - } - self.data = form_data - case "form-data": - form_data = { - 
self.variable_pool.convert_template(item.key).text: self.variable_pool.convert_template( - item.value - ).text - for item in filter(lambda item: item.type == "text", data) - } - file_selectors = { - self.variable_pool.convert_template(item.key).text: item.file - for item in filter(lambda item: item.type == "file", data) - } - - # get files from file_selectors, add support for array file variables - files_list = [] - for key, selector in file_selectors.items(): - segment = self.variable_pool.get(selector) - if isinstance(segment, FileSegment): - files_list.append((key, [segment.value])) - elif isinstance(segment, ArrayFileSegment): - files_list.append((key, list(segment.value))) - - # get files from file_manager - files: dict[str, list[tuple[str | None, bytes, str]]] = {} - for key, files_in_segment in files_list: - for file in files_in_segment: - if file.related_id is not None or ( - file.transfer_method == FileTransferMethod.REMOTE_URL and file.remote_url is not None - ): - file_tuple = ( - file.filename, - self._file_manager.download(file), - file.mime_type or "application/octet-stream", - ) - if key not in files: - files[key] = [] - files[key].append(file_tuple) - - # convert files to list for httpx request - # If there are no actual files, we still need to force httpx to use `multipart/form-data`. - # This is achieved by inserting a harmless placeholder file that will be ignored by the server. 
- if not files: - self.files = [("__multipart_placeholder__", ("", b"", "application/octet-stream"))] - if files: - self.files = [] - for key, file_tuples in files.items(): - for file_tuple in file_tuples: - self.files.append((key, file_tuple)) - - self.data = form_data - - def _assembling_headers(self) -> dict[str, Any]: - authorization = deepcopy(self.auth) - headers = deepcopy(self.headers) or {} - if self.auth.type == "api-key": - if self.auth.config is None: - raise AuthorizationConfigError("self.authorization config is required") - if authorization.config is None: - raise AuthorizationConfigError("authorization config is required") - - if not authorization.config.header: - authorization.config.header = "Authorization" - - if self.auth.config.type == "bearer" and authorization.config.api_key: - headers[authorization.config.header] = f"Bearer {authorization.config.api_key}" - elif self.auth.config.type == "basic" and authorization.config.api_key: - credentials = authorization.config.api_key - if ":" in credentials: - encoded_credentials = base64.b64encode(credentials.encode("utf-8")).decode("utf-8") - else: - encoded_credentials = credentials - headers[authorization.config.header] = f"Basic {encoded_credentials}" - elif self.auth.config.type == "custom": - if authorization.config.header and authorization.config.api_key: - headers[authorization.config.header] = authorization.config.api_key - - # Handle Content-Type for multipart/form-data requests - # Fix for issue #23829: Missing boundary when using multipart/form-data - body = self.node_data.body - if body and body.type == "form-data": - # For multipart/form-data with files (including placeholder files), - # remove any manually set Content-Type header to let httpx handle - # For multipart/form-data, if any files are present (including placeholder files), - # we must remove any manually set Content-Type header. 
This is because httpx needs to - # automatically set the Content-Type and boundary for multipart encoding whenever files - # are included, even if they are placeholders, to avoid boundary issues and ensure correct - # file upload behaviour. Manually setting Content-Type can cause httpx to fail to set the - # boundary, resulting in invalid requests. - if self.files: - # Remove Content-Type if it was manually set to avoid boundary issues - headers = {k: v for k, v in headers.items() if k.lower() != "content-type"} - else: - # No files at all, set Content-Type manually - if "content-type" not in (k.lower() for k in headers): - headers["Content-Type"] = "multipart/form-data" - elif body and body.type in BODY_TYPE_TO_CONTENT_TYPE: - # Set Content-Type for other body types - if "content-type" not in (k.lower() for k in headers): - headers["Content-Type"] = BODY_TYPE_TO_CONTENT_TYPE[body.type] - - return headers - - def _validate_and_parse_response(self, response: httpx.Response) -> Response: - executor_response = Response(response) - - threshold_size = ( - self._http_request_config.max_binary_size - if executor_response.is_file - else self._http_request_config.max_text_size - ) - if executor_response.size > threshold_size: - raise ResponseSizeError( - f"{'File' if executor_response.is_file else 'Text'} size is too large," - f" max size is {threshold_size / 1024 / 1024:.2f} MB," - f" but current size is {executor_response.readable_size}." 
- ) - - return executor_response - - def _do_http_request(self, headers: dict[str, Any]) -> httpx.Response: - """ - do http request depending on api bundle - """ - _METHOD_MAP: dict[str, Callable[..., httpx.Response]] = { - "get": self._http_client.get, - "head": self._http_client.head, - "post": self._http_client.post, - "put": self._http_client.put, - "delete": self._http_client.delete, - "patch": self._http_client.patch, - } - method_lc = self.method.lower() - if method_lc not in _METHOD_MAP: - raise InvalidHttpMethodError(f"Invalid http method {self.method}") - - request_args: dict[str, Any] = { - "data": self.data, - "files": self.files, - "json": self.json, - "content": self.content, - "headers": headers, - "params": self.params, - "timeout": (self.timeout.connect, self.timeout.read, self.timeout.write), - "ssl_verify": self.ssl_verify, - "follow_redirects": True, - } - # request_args = {k: v for k, v in request_args.items() if v is not None} - try: - response = _METHOD_MAP[method_lc]( - url=self.url, - **request_args, - max_retries=self.max_retries, - ) - except self._http_client.max_retries_exceeded_error as e: - raise HttpRequestNodeError(f"Reached maximum retries for URL {self.url}") from e - except self._http_client.request_error as e: - raise HttpRequestNodeError(str(e)) from e - return response - - def invoke(self) -> Response: - # assemble headers - headers = self._assembling_headers() - # do http request - response = self._do_http_request(headers) - # validate response - return self._validate_and_parse_response(response) - - def to_log(self): - url_parts = urlparse(self.url) - path = url_parts.path or "/" - - # Add query parameters - if self.params: - query_string = urlencode(self.params) - path += f"?{query_string}" - elif url_parts.query: - path += f"?{url_parts.query}" - - raw = f"{self.method.upper()} {path} HTTP/1.1\r\n" - raw += f"Host: {url_parts.netloc}\r\n" - - headers = self._assembling_headers() - body = self.node_data.body - boundary = 
f"----WebKitFormBoundary{_generate_random_string(16)}" - if body: - if "content-type" not in (k.lower() for k in self.headers) and body.type in BODY_TYPE_TO_CONTENT_TYPE: - headers["Content-Type"] = BODY_TYPE_TO_CONTENT_TYPE[body.type] - if body.type == "form-data": - headers["Content-Type"] = f"multipart/form-data; boundary={boundary}" - for k, v in headers.items(): - if self.auth.type == "api-key": - authorization_header = "Authorization" - if self.auth.config and self.auth.config.header: - authorization_header = self.auth.config.header - if k.lower() == authorization_header.lower(): - raw += f"{k}: {'*' * len(v)}\r\n" - continue - raw += f"{k}: {v}\r\n" - - body_string = "" - # Only log actual files if present. - # '__multipart_placeholder__' is inserted to force multipart encoding but is not a real file. - # This prevents logging meaningless placeholder entries. - if self.files and not all(f[0] == "__multipart_placeholder__" for f in self.files): - for file_entry in self.files: - # file_entry should be (key, (filename, content, mime_type)), but handle edge cases - if len(file_entry) != 2 or len(file_entry[1]) < 2: - continue # skip malformed entries - key = file_entry[0] - content = file_entry[1][1] - body_string += f"--{boundary}\r\n" - body_string += f'Content-Disposition: form-data; name="{key}"\r\n\r\n' - # decode content safely - # Do not decode binary content; use a placeholder with file metadata instead. - # Includes filename, size, and MIME type for better logging context. - body_string += ( - f"\r\n" - ) - body_string += f"--{boundary}--\r\n" - elif self.node_data.body: - if self.content: - # If content is bytes, do not decode it; show a placeholder with size. - # Provides content size information for binary data without exposing the raw bytes. 
- if isinstance(self.content, bytes): - body_string = f"" - else: - body_string = self.content - elif self.data and self.node_data.body.type == "x-www-form-urlencoded": - body_string = urlencode(self.data) - elif self.data and self.node_data.body.type == "form-data": - for key, value in self.data.items(): - body_string += f"--{boundary}\r\n" - body_string += f'Content-Disposition: form-data; name="{key}"\r\n\r\n' - body_string += f"{value}\r\n" - body_string += f"--{boundary}--\r\n" - elif self.json: - body_string = json.dumps(self.json) - elif self.node_data.body.type == "raw-text": - if len(self.node_data.body.data) != 1: - raise RequestBodyError("raw-text body type should have exactly one item") - body_string = self.node_data.body.data[0].value - if body_string: - raw += f"Content-Length: {len(body_string)}\r\n" - raw += "\r\n" # Empty line between headers and body - raw += body_string - - return raw - - -def _generate_random_string(n: int) -> str: - """ - Generate a random string of lowercase ASCII letters. - - Args: - n (int): The length of the random string to generate. - - Returns: - str: A random string of lowercase ASCII letters with length n. 
- - Example: - >>> _generate_random_string(5) - 'abcde' - """ - return "".join(secrets.choice(string.ascii_lowercase) for _ in range(n)) diff --git a/api/dify_graph/nodes/http_request/node.py b/api/dify_graph/nodes/http_request/node.py deleted file mode 100644 index 3e5253d809..0000000000 --- a/api/dify_graph/nodes/http_request/node.py +++ /dev/null @@ -1,261 +0,0 @@ -import logging -import mimetypes -from collections.abc import Callable, Mapping, Sequence -from typing import TYPE_CHECKING, Any - -from dify_graph.entities.graph_config import NodeConfigDict -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from dify_graph.file import File, FileTransferMethod -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.base import variable_template_parser -from dify_graph.nodes.base.entities import VariableSelector -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.http_request.executor import Executor -from dify_graph.nodes.protocols import FileManagerProtocol, HttpClientProtocol, ToolFileManagerProtocol -from dify_graph.variables.segments import ArrayFileSegment -from factories import file_factory - -from .config import build_http_request_config, resolve_http_request_config -from .entities import ( - HTTP_REQUEST_CONFIG_FILTER_KEY, - HttpRequestNodeConfig, - HttpRequestNodeData, - HttpRequestNodeTimeout, - Response, -) -from .exc import HttpRequestNodeError, RequestBodyError - -logger = logging.getLogger(__name__) - -if TYPE_CHECKING: - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState - - -class HttpRequestNode(Node[HttpRequestNodeData]): - node_type = BuiltinNodeTypes.HTTP_REQUEST - - def __init__( - self, - id: str, - config: NodeConfigDict, - graph_init_params: "GraphInitParams", - graph_runtime_state: "GraphRuntimeState", - *, - http_request_config: HttpRequestNodeConfig, - http_client: HttpClientProtocol, - tool_file_manager_factory: Callable[[], 
ToolFileManagerProtocol], - file_manager: FileManagerProtocol, - ) -> None: - super().__init__( - id=id, - config=config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - - self._http_request_config = http_request_config - self._http_client = http_client - self._tool_file_manager_factory = tool_file_manager_factory - self._file_manager = file_manager - - @classmethod - def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]: - if not filters or HTTP_REQUEST_CONFIG_FILTER_KEY not in filters: - http_request_config = build_http_request_config() - else: - http_request_config = resolve_http_request_config(filters) - default_timeout = http_request_config.default_timeout() - return { - "type": "http-request", - "config": { - "method": "get", - "authorization": { - "type": "no-auth", - }, - "body": {"type": "none"}, - "timeout": { - **default_timeout.model_dump(), - "max_connect_timeout": http_request_config.max_connect_timeout, - "max_read_timeout": http_request_config.max_read_timeout, - "max_write_timeout": http_request_config.max_write_timeout, - }, - "ssl_verify": http_request_config.ssl_verify, - }, - "retry_config": { - "max_retries": http_request_config.ssrf_default_max_retries, - "retry_interval": 0.5 * (2**2), - "retry_enabled": True, - }, - } - - @classmethod - def version(cls) -> str: - return "1" - - def _run(self) -> NodeRunResult: - process_data = {} - try: - http_executor = Executor( - node_data=self.node_data, - timeout=self._get_request_timeout(self.node_data), - variable_pool=self.graph_runtime_state.variable_pool, - http_request_config=self._http_request_config, - # Must be 0 to disable executor-level retries, as the graph engine handles them. - # This is critical to prevent nested retries. 
- max_retries=0, - ssl_verify=self.node_data.ssl_verify, - http_client=self._http_client, - file_manager=self._file_manager, - ) - process_data["request"] = http_executor.to_log() - - response = http_executor.invoke() - files = self.extract_files(url=http_executor.url, response=response) - if not response.response.is_success and (self.error_strategy or self.retry): - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - outputs={ - "status_code": response.status_code, - "body": response.text if not files.value else "", - "headers": response.headers, - "files": files, - }, - process_data={ - "request": http_executor.to_log(), - }, - error=f"Request failed with status code {response.status_code}", - error_type="HTTPResponseCodeError", - ) - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs={ - "status_code": response.status_code, - "body": response.text if not files.value else "", - "headers": response.headers, - "files": files, - }, - process_data={ - "request": http_executor.to_log(), - }, - ) - except HttpRequestNodeError as e: - logger.warning("http request node %s failed to run: %s", self._node_id, e) - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error=str(e), - process_data=process_data, - error_type=type(e).__name__, - ) - - def _get_request_timeout(self, node_data: HttpRequestNodeData) -> HttpRequestNodeTimeout: - default_timeout = self._http_request_config.default_timeout() - timeout = node_data.timeout - if timeout is None: - return default_timeout - - return HttpRequestNodeTimeout( - connect=timeout.connect or default_timeout.connect, - read=timeout.read or default_timeout.read, - write=timeout.write or default_timeout.write, - ) - - @classmethod - def _extract_variable_selector_to_variable_mapping( - cls, - *, - graph_config: Mapping[str, Any], - node_id: str, - node_data: HttpRequestNodeData, - ) -> Mapping[str, Sequence[str]]: - selectors: list[VariableSelector] = [] - selectors += 
variable_template_parser.extract_selectors_from_template(node_data.url) - selectors += variable_template_parser.extract_selectors_from_template(node_data.headers) - selectors += variable_template_parser.extract_selectors_from_template(node_data.params) - if node_data.body: - body_type = node_data.body.type - data = node_data.body.data - match body_type: - case "none": - pass - case "binary": - if len(data) != 1: - raise RequestBodyError("invalid body data, should have only one item") - selector = data[0].file - selectors.append(VariableSelector(variable="#" + ".".join(selector) + "#", value_selector=selector)) - case "json" | "raw-text": - if len(data) != 1: - raise RequestBodyError("invalid body data, should have only one item") - selectors += variable_template_parser.extract_selectors_from_template(data[0].key) - selectors += variable_template_parser.extract_selectors_from_template(data[0].value) - case "x-www-form-urlencoded": - for item in data: - selectors += variable_template_parser.extract_selectors_from_template(item.key) - selectors += variable_template_parser.extract_selectors_from_template(item.value) - case "form-data": - for item in data: - selectors += variable_template_parser.extract_selectors_from_template(item.key) - if item.type == "text": - selectors += variable_template_parser.extract_selectors_from_template(item.value) - elif item.type == "file": - selectors.append( - VariableSelector(variable="#" + ".".join(item.file) + "#", value_selector=item.file) - ) - - mapping = {} - for selector_iter in selectors: - mapping[node_id + "." 
+ selector_iter.variable] = selector_iter.value_selector - - return mapping - - def extract_files(self, url: str, response: Response) -> ArrayFileSegment: - """ - Extract files from response by checking both Content-Type header and URL - """ - dify_ctx = self.require_dify_context() - files: list[File] = [] - is_file = response.is_file - content_type = response.content_type - content = response.content - parsed_content_disposition = response.parsed_content_disposition - content_disposition_type = None - - if not is_file: - return ArrayFileSegment(value=[]) - - if parsed_content_disposition: - content_disposition_filename = parsed_content_disposition.get_filename() - if content_disposition_filename: - # If filename is available from content-disposition, use it to guess the content type - content_disposition_type = mimetypes.guess_type(content_disposition_filename)[0] - - # Guess file extension from URL or Content-Type header - filename = url.split("?")[0].split("/")[-1] or "" - mime_type = ( - content_disposition_type or content_type or mimetypes.guess_type(filename)[0] or "application/octet-stream" - ) - tool_file_manager = self._tool_file_manager_factory() - - tool_file = tool_file_manager.create_file_by_raw( - user_id=dify_ctx.user_id, - tenant_id=dify_ctx.tenant_id, - conversation_id=None, - file_binary=content, - mimetype=mime_type, - ) - - mapping = { - "tool_file_id": tool_file.id, - "transfer_method": FileTransferMethod.TOOL_FILE, - } - file = file_factory.build_from_mapping( - mapping=mapping, - tenant_id=dify_ctx.tenant_id, - ) - files.append(file) - - return ArrayFileSegment(value=files) - - @property - def retry(self) -> bool: - return self.node_data.retry_config.retry_enabled diff --git a/api/dify_graph/nodes/human_input/__init__.py b/api/dify_graph/nodes/human_input/__init__.py deleted file mode 100644 index 1789604577..0000000000 --- a/api/dify_graph/nodes/human_input/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -""" -Human Input node implementation. 
-""" diff --git a/api/dify_graph/nodes/human_input/entities.py b/api/dify_graph/nodes/human_input/entities.py deleted file mode 100644 index 2a33b4a0a8..0000000000 --- a/api/dify_graph/nodes/human_input/entities.py +++ /dev/null @@ -1,424 +0,0 @@ -""" -Human Input node entities. -""" - -import re -import uuid -from collections.abc import Mapping, Sequence -from datetime import datetime, timedelta -from typing import Annotated, Any, ClassVar, Literal, Self - -import bleach -import markdown -from pydantic import BaseModel, Field, field_validator, model_validator - -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, NodeType -from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser -from dify_graph.runtime import VariablePool -from dify_graph.variables.consts import SELECTORS_LENGTH - -from .enums import ButtonStyle, DeliveryMethodType, EmailRecipientType, FormInputType, PlaceholderType, TimeoutUnit - -_OUTPUT_VARIABLE_PATTERN = re.compile(r"\{\{#\$output\.(?P[a-zA-Z_][a-zA-Z0-9_]{0,29})#\}\}") - - -class _WebAppDeliveryConfig(BaseModel): - """Configuration for webapp delivery method.""" - - pass # Empty for webapp delivery - - -class MemberRecipient(BaseModel): - """Member recipient for email delivery.""" - - type: Literal[EmailRecipientType.MEMBER] = EmailRecipientType.MEMBER - user_id: str - - -class ExternalRecipient(BaseModel): - """External recipient for email delivery.""" - - type: Literal[EmailRecipientType.EXTERNAL] = EmailRecipientType.EXTERNAL - email: str - - -EmailRecipient = Annotated[MemberRecipient | ExternalRecipient, Field(discriminator="type")] - - -class EmailRecipients(BaseModel): - """Email recipients configuration.""" - - # When true, recipients are the union of all workspace members and external items. - # Member items are ignored because they are already covered by the workspace scope. 
- # De-duplication is applied by email, with member recipients taking precedence. - whole_workspace: bool = False - items: list[EmailRecipient] = Field(default_factory=list) - - -class EmailDeliveryConfig(BaseModel): - """Configuration for email delivery method.""" - - URL_PLACEHOLDER: ClassVar[str] = "{{#url#}}" - _SUBJECT_NEWLINE_PATTERN: ClassVar[re.Pattern[str]] = re.compile(r"[\r\n]+") - _ALLOWED_HTML_TAGS: ClassVar[list[str]] = [ - "a", - "blockquote", - "br", - "code", - "em", - "h1", - "h2", - "h3", - "h4", - "h5", - "h6", - "hr", - "li", - "ol", - "p", - "pre", - "strong", - "table", - "tbody", - "td", - "th", - "thead", - "tr", - "ul", - ] - _ALLOWED_HTML_ATTRIBUTES: ClassVar[dict[str, list[str]]] = { - "a": ["href", "title"], - "td": ["align"], - "th": ["align"], - } - _ALLOWED_PROTOCOLS: ClassVar[list[str]] = ["http", "https", "mailto"] - - recipients: EmailRecipients - - # the subject of email - subject: str - - # Body is the content of email.It may contain the speical placeholder `{{#url#}}`, which - # represent the url to submit the form. - # - # It may also reference the output variable of the previous node with the syntax - # `{{#.#}}`. 
- body: str - debug_mode: bool = False - - def with_debug_recipient(self, user_id: str | None) -> "EmailDeliveryConfig": - if user_id is None: - debug_recipients = EmailRecipients(whole_workspace=False, items=[]) - return self.model_copy(update={"recipients": debug_recipients}) - debug_recipients = EmailRecipients(whole_workspace=False, items=[MemberRecipient(user_id=user_id)]) - return self.model_copy(update={"recipients": debug_recipients}) - - @classmethod - def replace_url_placeholder(cls, body: str, url: str | None) -> str: - """Replace the url placeholder with provided value.""" - return body.replace(cls.URL_PLACEHOLDER, url or "") - - @classmethod - def render_body_template( - cls, - *, - body: str, - url: str | None, - variable_pool: VariablePool | None = None, - ) -> str: - """Render email body by replacing placeholders with runtime values.""" - templated_body = cls.replace_url_placeholder(body, url) - if variable_pool is None: - return templated_body - return variable_pool.convert_template(templated_body).text - - @classmethod - def render_markdown_body(cls, body: str) -> str: - """Render markdown to safe HTML for email delivery.""" - sanitized_markdown = bleach.clean( - body, - tags=[], - attributes={}, - strip=True, - strip_comments=True, - ) - rendered_html = markdown.markdown( - sanitized_markdown, - extensions=["nl2br", "tables"], - extension_configs={"tables": {"use_align_attribute": True}}, - ) - return bleach.clean( - rendered_html, - tags=cls._ALLOWED_HTML_TAGS, - attributes=cls._ALLOWED_HTML_ATTRIBUTES, - protocols=cls._ALLOWED_PROTOCOLS, - strip=True, - strip_comments=True, - ) - - @classmethod - def sanitize_subject(cls, subject: str) -> str: - """Sanitize email subject to plain text and prevent CRLF injection.""" - sanitized_subject = bleach.clean( - subject, - tags=[], - attributes={}, - strip=True, - strip_comments=True, - ) - sanitized_subject = cls._SUBJECT_NEWLINE_PATTERN.sub(" ", sanitized_subject) - return " 
".join(sanitized_subject.split()) - - -class _DeliveryMethodBase(BaseModel): - """Base delivery method configuration.""" - - enabled: bool = True - id: uuid.UUID = Field(default_factory=uuid.uuid4) - - def extract_variable_selectors(self) -> Sequence[Sequence[str]]: - return () - - -class WebAppDeliveryMethod(_DeliveryMethodBase): - """Webapp delivery method configuration.""" - - type: Literal[DeliveryMethodType.WEBAPP] = DeliveryMethodType.WEBAPP - # The config field is not used currently. - config: _WebAppDeliveryConfig = Field(default_factory=_WebAppDeliveryConfig) - - -class EmailDeliveryMethod(_DeliveryMethodBase): - """Email delivery method configuration.""" - - type: Literal[DeliveryMethodType.EMAIL] = DeliveryMethodType.EMAIL - config: EmailDeliveryConfig - - def extract_variable_selectors(self) -> Sequence[Sequence[str]]: - variable_template_parser = VariableTemplateParser(template=self.config.body) - selectors: list[Sequence[str]] = [] - for variable_selector in variable_template_parser.extract_variable_selectors(): - value_selector = list(variable_selector.value_selector) - if len(value_selector) < SELECTORS_LENGTH: - continue - selectors.append(value_selector[:SELECTORS_LENGTH]) - return selectors - - -DeliveryChannelConfig = Annotated[WebAppDeliveryMethod | EmailDeliveryMethod, Field(discriminator="type")] - - -def apply_debug_email_recipient( - method: DeliveryChannelConfig, - *, - enabled: bool, - user_id: str | None, -) -> DeliveryChannelConfig: - if not enabled: - return method - if not isinstance(method, EmailDeliveryMethod): - return method - if not method.config.debug_mode: - return method - debug_config = method.config.with_debug_recipient(user_id) - return method.model_copy(update={"config": debug_config}) - - -class FormInputDefault(BaseModel): - """Default configuration for form inputs.""" - - # NOTE: Ideally, a discriminated union would be used to model - # FormInputDefault. 
However, the UI requires preserving the previous - # value when switching between `VARIABLE` and `CONSTANT` types. This - # necessitates retaining all fields, making a discriminated union unsuitable. - - type: PlaceholderType - - # The selector of default variable, used when `type` is `VARIABLE`. - selector: Sequence[str] = Field(default_factory=tuple) # - - # The value of the default, used when `type` is `CONSTANT`. - # TODO: How should we express JSON values? - value: str = "" - - @model_validator(mode="after") - def _validate_selector(self) -> Self: - if self.type == PlaceholderType.CONSTANT: - return self - if len(self.selector) < SELECTORS_LENGTH: - raise ValueError(f"the length of selector should be at least {SELECTORS_LENGTH}, selector={self.selector}") - return self - - -class FormInput(BaseModel): - """Form input definition.""" - - type: FormInputType - output_variable_name: str - default: FormInputDefault | None = None - - -_IDENTIFIER_PATTERN = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$") - - -class UserAction(BaseModel): - """User action configuration.""" - - # id is the identifier for this action. - # It also serves as the identifiers of output handle. - # - # The id must be a valid identifier (satisfy the _IDENTIFIER_PATTERN above.) - id: str = Field(max_length=20) - title: str = Field(max_length=20) - button_style: ButtonStyle = ButtonStyle.DEFAULT - - @field_validator("id") - @classmethod - def _validate_id(cls, value: str) -> str: - if not _IDENTIFIER_PATTERN.match(value): - raise ValueError( - f"'{value}' is not a valid identifier. It must start with a letter or underscore, " - f"and contain only letters, numbers, or underscores." 
- ) - return value - - -class HumanInputNodeData(BaseNodeData): - """Human Input node data.""" - - type: NodeType = BuiltinNodeTypes.HUMAN_INPUT - delivery_methods: list[DeliveryChannelConfig] = Field(default_factory=list) - form_content: str = "" - inputs: list[FormInput] = Field(default_factory=list) - user_actions: list[UserAction] = Field(default_factory=list) - timeout: int = 36 - timeout_unit: TimeoutUnit = TimeoutUnit.HOUR - - @field_validator("inputs") - @classmethod - def _validate_inputs(cls, inputs: list[FormInput]) -> list[FormInput]: - seen_names: set[str] = set() - for form_input in inputs: - name = form_input.output_variable_name - if name in seen_names: - raise ValueError(f"duplicated output_variable_name '{name}' in inputs") - seen_names.add(name) - return inputs - - @field_validator("user_actions") - @classmethod - def _validate_user_actions(cls, user_actions: list[UserAction]) -> list[UserAction]: - seen_ids: set[str] = set() - for action in user_actions: - action_id = action.id - if action_id in seen_ids: - raise ValueError(f"duplicated user action id '{action_id}'") - seen_ids.add(action_id) - return user_actions - - def is_webapp_enabled(self) -> bool: - for dm in self.delivery_methods: - if not dm.enabled: - continue - if dm.type == DeliveryMethodType.WEBAPP: - return True - return False - - def expiration_time(self, start_time: datetime) -> datetime: - if self.timeout_unit == TimeoutUnit.HOUR: - return start_time + timedelta(hours=self.timeout) - elif self.timeout_unit == TimeoutUnit.DAY: - return start_time + timedelta(days=self.timeout) - else: - raise AssertionError("unknown timeout unit.") - - def outputs_field_names(self) -> Sequence[str]: - field_names = [] - for match in _OUTPUT_VARIABLE_PATTERN.finditer(self.form_content): - field_names.append(match.group("field_name")) - return field_names - - def extract_variable_selector_to_variable_mapping(self, node_id: str) -> Mapping[str, Sequence[str]]: - variable_mappings: dict[str, 
Sequence[str]] = {} - - def _add_variable_selectors(selectors: Sequence[Sequence[str]]) -> None: - for selector in selectors: - if len(selector) < SELECTORS_LENGTH: - continue - qualified_variable_mapping_key = f"{node_id}.#{'.'.join(selector[:SELECTORS_LENGTH])}#" - variable_mappings[qualified_variable_mapping_key] = list(selector[:SELECTORS_LENGTH]) - - form_template_parser = VariableTemplateParser(template=self.form_content) - _add_variable_selectors( - [selector.value_selector for selector in form_template_parser.extract_variable_selectors()] - ) - for delivery_method in self.delivery_methods: - if not delivery_method.enabled: - continue - _add_variable_selectors(delivery_method.extract_variable_selectors()) - - for input in self.inputs: - default_value = input.default - if default_value is None: - continue - if default_value.type == PlaceholderType.CONSTANT: - continue - default_value_key = ".".join(default_value.selector) - qualified_variable_mapping_key = f"{node_id}.#{default_value_key}#" - variable_mappings[qualified_variable_mapping_key] = default_value.selector - - return variable_mappings - - def find_action_text(self, action_id: str) -> str: - """ - Resolve action display text by id. - """ - for action in self.user_actions: - if action.id == action_id: - return action.title - return action_id - - -class FormDefinition(BaseModel): - form_content: str - inputs: list[FormInput] = Field(default_factory=list) - user_actions: list[UserAction] = Field(default_factory=list) - rendered_content: str - expiration_time: datetime - - # this is used to store the resolved default values - default_values: dict[str, Any] = Field(default_factory=dict) - - # node_title records the title of the HumanInput node. - node_title: str | None = None - - # display_in_ui controls whether the form should be displayed in UI surfaces. 
- display_in_ui: bool | None = None - - -class HumanInputSubmissionValidationError(ValueError): - pass - - -def validate_human_input_submission( - *, - inputs: Sequence[FormInput], - user_actions: Sequence[UserAction], - selected_action_id: str, - form_data: Mapping[str, Any], -) -> None: - available_actions = {action.id for action in user_actions} - if selected_action_id not in available_actions: - raise HumanInputSubmissionValidationError(f"Invalid action: {selected_action_id}") - - provided_inputs = set(form_data.keys()) - missing_inputs = [ - form_input.output_variable_name - for form_input in inputs - if form_input.output_variable_name not in provided_inputs - ] - - if missing_inputs: - missing_list = ", ".join(missing_inputs) - raise HumanInputSubmissionValidationError(f"Missing required inputs: {missing_list}") diff --git a/api/dify_graph/nodes/human_input/enums.py b/api/dify_graph/nodes/human_input/enums.py deleted file mode 100644 index da85728828..0000000000 --- a/api/dify_graph/nodes/human_input/enums.py +++ /dev/null @@ -1,72 +0,0 @@ -import enum - - -class HumanInputFormStatus(enum.StrEnum): - """Status of a human input form.""" - - # Awaiting submission from any recipient. Forms stay in this state until - # submitted or a timeout rule applies. - WAITING = enum.auto() - # Global timeout reached. The workflow run is stopped and will not resume. - # This is distinct from node-level timeout. - EXPIRED = enum.auto() - # Submitted by a recipient; form data is available and execution resumes - # along the selected action edge. - SUBMITTED = enum.auto() - # Node-level timeout reached. The human input node should emit a timeout - # event and the workflow should resume along the timeout edge. - TIMEOUT = enum.auto() - - -class HumanInputFormKind(enum.StrEnum): - """Kind of a human input form.""" - - RUNTIME = enum.auto() # Form created during workflow execution. - DELIVERY_TEST = enum.auto() # Form created for delivery tests. 
- - -class DeliveryMethodType(enum.StrEnum): - """Delivery method types for human input forms.""" - - # WEBAPP controls whether the form is delivered to the web app. It not only controls - # the standalone web app, but also controls the installed apps in the console. - WEBAPP = enum.auto() - - EMAIL = enum.auto() - - -class ButtonStyle(enum.StrEnum): - """Button styles for user actions.""" - - PRIMARY = enum.auto() - DEFAULT = enum.auto() - ACCENT = enum.auto() - GHOST = enum.auto() - - -class TimeoutUnit(enum.StrEnum): - """Timeout unit for form expiration.""" - - HOUR = enum.auto() - DAY = enum.auto() - - -class FormInputType(enum.StrEnum): - """Form input types.""" - - TEXT_INPUT = enum.auto() - PARAGRAPH = enum.auto() - - -class PlaceholderType(enum.StrEnum): - """Default value types for form inputs.""" - - VARIABLE = enum.auto() - CONSTANT = enum.auto() - - -class EmailRecipientType(enum.StrEnum): - """Email recipient types.""" - - MEMBER = enum.auto() - EXTERNAL = enum.auto() diff --git a/api/dify_graph/nodes/human_input/human_input_node.py b/api/dify_graph/nodes/human_input/human_input_node.py deleted file mode 100644 index 794e33d92e..0000000000 --- a/api/dify_graph/nodes/human_input/human_input_node.py +++ /dev/null @@ -1,361 +0,0 @@ -import json -import logging -from collections.abc import Generator, Mapping, Sequence -from typing import TYPE_CHECKING, Any - -from dify_graph.entities.graph_config import NodeConfigDict -from dify_graph.entities.pause_reason import HumanInputRequired -from dify_graph.enums import BuiltinNodeTypes, NodeExecutionType, WorkflowNodeExecutionStatus -from dify_graph.node_events import ( - HumanInputFormFilledEvent, - HumanInputFormTimeoutEvent, - NodeRunResult, - PauseRequestedEvent, -) -from dify_graph.node_events.base import NodeEventBase -from dify_graph.node_events.node import StreamCompletedEvent -from dify_graph.nodes.base.node import Node -from dify_graph.repositories.human_input_form_repository import ( - 
FormCreateParams, - HumanInputFormEntity, - HumanInputFormRepository, -) -from dify_graph.workflow_type_encoder import WorkflowRuntimeTypeConverter -from libs.datetime_utils import naive_utc_now - -from .entities import DeliveryChannelConfig, HumanInputNodeData, apply_debug_email_recipient -from .enums import DeliveryMethodType, HumanInputFormStatus, PlaceholderType - -if TYPE_CHECKING: - from dify_graph.entities.graph_init_params import GraphInitParams - from dify_graph.runtime.graph_runtime_state import GraphRuntimeState - - -_SELECTED_BRANCH_KEY = "selected_branch" -_INVOKE_FROM_DEBUGGER = "debugger" -_INVOKE_FROM_EXPLORE = "explore" - - -logger = logging.getLogger(__name__) - - -class HumanInputNode(Node[HumanInputNodeData]): - node_type = BuiltinNodeTypes.HUMAN_INPUT - execution_type = NodeExecutionType.BRANCH - - _BRANCH_SELECTION_KEYS: tuple[str, ...] = ( - "edge_source_handle", - "edgeSourceHandle", - "source_handle", - _SELECTED_BRANCH_KEY, - "selectedBranch", - "branch", - "branch_id", - "branchId", - "handle", - ) - - _node_data: HumanInputNodeData - _form_repository: HumanInputFormRepository - _OUTPUT_FIELD_ACTION_ID = "__action_id" - _OUTPUT_FIELD_RENDERED_CONTENT = "__rendered_content" - _TIMEOUT_HANDLE = _TIMEOUT_ACTION_ID = "__timeout" - - def __init__( - self, - id: str, - config: NodeConfigDict, - graph_init_params: "GraphInitParams", - graph_runtime_state: "GraphRuntimeState", - form_repository: HumanInputFormRepository, - ) -> None: - super().__init__( - id=id, - config=config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - self._form_repository = form_repository - - @classmethod - def version(cls) -> str: - return "1" - - def _resolve_branch_selection(self) -> str | None: - """Determine the branch handle selected by human input if available.""" - - variable_pool = self.graph_runtime_state.variable_pool - - for key in self._BRANCH_SELECTION_KEYS: - handle = 
self._extract_branch_handle(variable_pool.get((self.id, key))) - if handle: - return handle - - default_values = self.node_data.default_value_dict - for key in self._BRANCH_SELECTION_KEYS: - handle = self._normalize_branch_value(default_values.get(key)) - if handle: - return handle - - return None - - @staticmethod - def _extract_branch_handle(segment: Any) -> str | None: - if segment is None: - return None - - candidate = getattr(segment, "to_object", None) - raw_value = candidate() if callable(candidate) else getattr(segment, "value", None) - if raw_value is None: - return None - - return HumanInputNode._normalize_branch_value(raw_value) - - @staticmethod - def _normalize_branch_value(value: Any) -> str | None: - if value is None: - return None - - if isinstance(value, str): - stripped = value.strip() - return stripped or None - - if isinstance(value, Mapping): - for key in ("handle", "edge_source_handle", "edgeSourceHandle", "branch", "id", "value"): - candidate = value.get(key) - if isinstance(candidate, str) and candidate: - return candidate - - return None - - @property - def _workflow_execution_id(self) -> str: - workflow_exec_id = self.graph_runtime_state.variable_pool.system_variables.workflow_execution_id - assert workflow_exec_id is not None - return workflow_exec_id - - def _form_to_pause_event(self, form_entity: HumanInputFormEntity): - required_event = self._human_input_required_event(form_entity) - pause_requested_event = PauseRequestedEvent(reason=required_event) - return pause_requested_event - - def resolve_default_values(self) -> Mapping[str, Any]: - variable_pool = self.graph_runtime_state.variable_pool - resolved_defaults: dict[str, Any] = {} - for input in self._node_data.inputs: - if (default_value := input.default) is None: - continue - if default_value.type == PlaceholderType.CONSTANT: - continue - resolved_value = variable_pool.get(default_value.selector) - if resolved_value is None: - # TODO: How should we handle this? 
- continue - resolved_defaults[input.output_variable_name] = ( - WorkflowRuntimeTypeConverter().value_to_json_encodable_recursive(resolved_value.value) - ) - - return resolved_defaults - - def _should_require_console_recipient(self) -> bool: - invoke_from = self._invoke_from_value() - if invoke_from == _INVOKE_FROM_DEBUGGER: - return True - if invoke_from == _INVOKE_FROM_EXPLORE: - return self._node_data.is_webapp_enabled() - return False - - def _display_in_ui(self) -> bool: - if self._invoke_from_value() == _INVOKE_FROM_DEBUGGER: - return True - return self._node_data.is_webapp_enabled() - - def _effective_delivery_methods(self) -> Sequence[DeliveryChannelConfig]: - dify_ctx = self.require_dify_context() - invoke_from = self._invoke_from_value() - enabled_methods = [method for method in self._node_data.delivery_methods if method.enabled] - if invoke_from in {_INVOKE_FROM_DEBUGGER, _INVOKE_FROM_EXPLORE}: - enabled_methods = [method for method in enabled_methods if method.type != DeliveryMethodType.WEBAPP] - return [ - apply_debug_email_recipient( - method, - enabled=invoke_from == _INVOKE_FROM_DEBUGGER, - user_id=dify_ctx.user_id, - ) - for method in enabled_methods - ] - - def _invoke_from_value(self) -> str: - invoke_from = self.require_dify_context().invoke_from - if isinstance(invoke_from, str): - return invoke_from - return str(getattr(invoke_from, "value", invoke_from)) - - def _human_input_required_event(self, form_entity: HumanInputFormEntity) -> HumanInputRequired: - node_data = self._node_data - resolved_default_values = self.resolve_default_values() - display_in_ui = self._display_in_ui() - form_token = form_entity.web_app_token - if display_in_ui and form_token is None: - raise AssertionError("Form token should be available for UI execution.") - return HumanInputRequired( - form_id=form_entity.id, - form_content=form_entity.rendered_content, - inputs=node_data.inputs, - actions=node_data.user_actions, - display_in_ui=display_in_ui, - node_id=self.id, - 
node_title=node_data.title, - form_token=form_token, - resolved_default_values=resolved_default_values, - ) - - def _run(self) -> Generator[NodeEventBase, None, None]: - """ - Execute the human input node. - - This method will: - 1. Generate a unique form ID - 2. Create form content with variable substitution - 3. Create form in database - 4. Send form via configured delivery methods - 5. Suspend workflow execution - 6. Wait for form submission to resume - """ - repo = self._form_repository - form = repo.get_form(self._workflow_execution_id, self.id) - dify_ctx = self.require_dify_context() - if form is None: - display_in_ui = self._display_in_ui() - params = FormCreateParams( - app_id=dify_ctx.app_id, - workflow_execution_id=self._workflow_execution_id, - node_id=self.id, - form_config=self._node_data, - rendered_content=self.render_form_content_before_submission(), - delivery_methods=self._effective_delivery_methods(), - display_in_ui=display_in_ui, - resolved_default_values=self.resolve_default_values(), - console_recipient_required=self._should_require_console_recipient(), - console_creator_account_id=( - dify_ctx.user_id - if self._invoke_from_value() in {_INVOKE_FROM_DEBUGGER, _INVOKE_FROM_EXPLORE} - else None - ), - backstage_recipient_required=True, - ) - form_entity = self._form_repository.create_form(params) - # Create human input required event - - logger.info( - "Human Input node suspended workflow for form. 
workflow_run_id=%s, node_id=%s, form_id=%s", - self.graph_runtime_state.variable_pool.system_variables.workflow_execution_id, - self.id, - form_entity.id, - ) - yield self._form_to_pause_event(form_entity) - return - - if ( - form.status in {HumanInputFormStatus.TIMEOUT, HumanInputFormStatus.EXPIRED} - or form.expiration_time <= naive_utc_now() - ): - yield HumanInputFormTimeoutEvent( - node_title=self._node_data.title, - expiration_time=form.expiration_time, - ) - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs={self._OUTPUT_FIELD_ACTION_ID: ""}, - edge_source_handle=self._TIMEOUT_HANDLE, - ) - ) - return - - if not form.submitted: - yield self._form_to_pause_event(form) - return - - selected_action_id = form.selected_action_id - if selected_action_id is None: - raise AssertionError(f"selected_action_id should not be None when form submitted, form_id={form.id}") - submitted_data = form.submitted_data or {} - outputs: dict[str, Any] = dict(submitted_data) - outputs[self._OUTPUT_FIELD_ACTION_ID] = selected_action_id - rendered_content = self.render_form_content_with_outputs( - form.rendered_content, - outputs, - self._node_data.outputs_field_names(), - ) - outputs[self._OUTPUT_FIELD_RENDERED_CONTENT] = rendered_content - - action_text = self._node_data.find_action_text(selected_action_id) - - yield HumanInputFormFilledEvent( - node_title=self._node_data.title, - rendered_content=rendered_content, - action_id=selected_action_id, - action_text=action_text, - ) - - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs=outputs, - edge_source_handle=selected_action_id, - ) - ) - - def render_form_content_before_submission(self) -> str: - """ - Process form content by substituting variables. - - This method should: - 1. Parse the form_content markdown - 2. Substitute {{#node_name.var_name#}} with actual values - 3. 
Keep {{#$output.field_name#}} placeholders for form inputs - """ - rendered_form_content = self.graph_runtime_state.variable_pool.convert_template( - self._node_data.form_content, - ) - return rendered_form_content.markdown - - @staticmethod - def render_form_content_with_outputs( - form_content: str, - outputs: Mapping[str, Any], - field_names: Sequence[str], - ) -> str: - """ - Replace {{#$output.xxx#}} placeholders with submitted values. - """ - rendered_content = form_content - for field_name in field_names: - placeholder = "{{#$output." + field_name + "#}}" - value = outputs.get(field_name) - if value is None: - replacement = "" - elif isinstance(value, (dict, list)): - replacement = json.dumps(value, ensure_ascii=False) - else: - replacement = str(value) - rendered_content = rendered_content.replace(placeholder, replacement) - return rendered_content - - @classmethod - def _extract_variable_selector_to_variable_mapping( - cls, - *, - graph_config: Mapping[str, Any], - node_id: str, - node_data: HumanInputNodeData, - ) -> Mapping[str, Sequence[str]]: - """ - Extract variable selectors referenced in form content and input default values. - - This method should parse: - 1. Variables referenced in form_content ({{#node_name.var_name#}}) - 2. 
Variables referenced in input default values - """ - return node_data.extract_variable_selector_to_variable_mapping(node_id) diff --git a/api/dify_graph/nodes/if_else/__init__.py b/api/dify_graph/nodes/if_else/__init__.py deleted file mode 100644 index afa0e8112c..0000000000 --- a/api/dify_graph/nodes/if_else/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .if_else_node import IfElseNode - -__all__ = ["IfElseNode"] diff --git a/api/dify_graph/nodes/if_else/entities.py b/api/dify_graph/nodes/if_else/entities.py deleted file mode 100644 index ff09f3c023..0000000000 --- a/api/dify_graph/nodes/if_else/entities.py +++ /dev/null @@ -1,29 +0,0 @@ -from typing import Literal - -from pydantic import BaseModel, Field - -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, NodeType -from dify_graph.utils.condition.entities import Condition - - -class IfElseNodeData(BaseNodeData): - """ - If Else Node Data. - """ - - type: NodeType = BuiltinNodeTypes.IF_ELSE - - class Case(BaseModel): - """ - Case entity representing a single logical condition group - """ - - case_id: str - logical_operator: Literal["and", "or"] - conditions: list[Condition] - - logical_operator: Literal["and", "or"] | None = "and" - conditions: list[Condition] | None = Field(default=None, deprecated=True) - - cases: list[Case] | None = None diff --git a/api/dify_graph/nodes/if_else/if_else_node.py b/api/dify_graph/nodes/if_else/if_else_node.py deleted file mode 100644 index 7c0370e48c..0000000000 --- a/api/dify_graph/nodes/if_else/if_else_node.py +++ /dev/null @@ -1,124 +0,0 @@ -from collections.abc import Mapping, Sequence -from typing import Any, Literal - -from typing_extensions import deprecated - -from dify_graph.enums import BuiltinNodeTypes, NodeExecutionType, WorkflowNodeExecutionStatus -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.if_else.entities import IfElseNodeData -from 
dify_graph.runtime import VariablePool -from dify_graph.utils.condition.entities import Condition -from dify_graph.utils.condition.processor import ConditionProcessor - - -class IfElseNode(Node[IfElseNodeData]): - node_type = BuiltinNodeTypes.IF_ELSE - execution_type = NodeExecutionType.BRANCH - - @classmethod - def version(cls) -> str: - return "1" - - def _run(self) -> NodeRunResult: - """ - Run node - :return: - """ - node_inputs: dict[str, Sequence[Mapping[str, Any]]] = {"conditions": []} - - process_data: dict[str, list] = {"condition_results": []} - - input_conditions: Sequence[Mapping[str, Any]] = [] - final_result = False - selected_case_id = "false" - condition_processor = ConditionProcessor() - try: - # Check if the new cases structure is used - if self.node_data.cases: - for case in self.node_data.cases: - input_conditions, group_result, final_result = condition_processor.process_conditions( - variable_pool=self.graph_runtime_state.variable_pool, - conditions=case.conditions, - operator=case.logical_operator, - ) - - process_data["condition_results"].append( - { - "group": case.model_dump(), - "results": group_result, - "final_result": final_result, - } - ) - - # Break if a case passes (logical short-circuit) - if final_result: - selected_case_id = case.case_id # Capture the ID of the passing case - break - - else: - # TODO: Update database then remove this - # Fallback to old structure if cases are not defined - input_conditions, group_result, final_result = _should_not_use_old_function( # pyright: ignore [reportDeprecated] - condition_processor=condition_processor, - variable_pool=self.graph_runtime_state.variable_pool, - conditions=self.node_data.conditions or [], - operator=self.node_data.logical_operator or "and", - ) - - selected_case_id = "true" if final_result else "false" - - process_data["condition_results"].append( - {"group": "default", "results": group_result, "final_result": final_result} - ) - - node_inputs["conditions"] = input_conditions 
- - except Exception as e: - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, inputs=node_inputs, process_data=process_data, error=str(e) - ) - - outputs = {"result": final_result, "selected_case_id": selected_case_id} - - data = NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=node_inputs, - process_data=process_data, - edge_source_handle=selected_case_id or "false", # Use case ID or 'default' - outputs=outputs, - ) - - return data - - @classmethod - def _extract_variable_selector_to_variable_mapping( - cls, - *, - graph_config: Mapping[str, Any], - node_id: str, - node_data: IfElseNodeData, - ) -> Mapping[str, Sequence[str]]: - var_mapping: dict[str, list[str]] = {} - _ = graph_config # Explicitly mark as unused - for case in node_data.cases or []: - for condition in case.conditions: - key = f"{node_id}.#{'.'.join(condition.variable_selector)}#" - var_mapping[key] = condition.variable_selector - - return var_mapping - - -@deprecated("This function is deprecated. 
You should use the new cases structure.") -def _should_not_use_old_function( - *, - condition_processor: ConditionProcessor, - variable_pool: VariablePool, - conditions: list[Condition], - operator: Literal["and", "or"], -): - return condition_processor.process_conditions( - variable_pool=variable_pool, - conditions=conditions, - operator=operator, - ) diff --git a/api/dify_graph/nodes/iteration/__init__.py b/api/dify_graph/nodes/iteration/__init__.py deleted file mode 100644 index 5bb87aaffa..0000000000 --- a/api/dify_graph/nodes/iteration/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -from .entities import IterationNodeData -from .iteration_node import IterationNode -from .iteration_start_node import IterationStartNode - -__all__ = ["IterationNode", "IterationNodeData", "IterationStartNode"] diff --git a/api/dify_graph/nodes/iteration/entities.py b/api/dify_graph/nodes/iteration/entities.py deleted file mode 100644 index 58fd112b12..0000000000 --- a/api/dify_graph/nodes/iteration/entities.py +++ /dev/null @@ -1,67 +0,0 @@ -from enum import StrEnum -from typing import Any - -from pydantic import Field - -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, NodeType -from dify_graph.nodes.base import BaseIterationNodeData, BaseIterationState - - -class ErrorHandleMode(StrEnum): - TERMINATED = "terminated" - CONTINUE_ON_ERROR = "continue-on-error" - REMOVE_ABNORMAL_OUTPUT = "remove-abnormal-output" - - -class IterationNodeData(BaseIterationNodeData): - """ - Iteration Node Data. 
- """ - - type: NodeType = BuiltinNodeTypes.ITERATION - parent_loop_id: str | None = None # redundant field, not used currently - iterator_selector: list[str] # variable selector - output_selector: list[str] # output selector - is_parallel: bool = False # open the parallel mode or not - parallel_nums: int = 10 # the numbers of parallel - error_handle_mode: ErrorHandleMode = ErrorHandleMode.TERMINATED # how to handle the error - flatten_output: bool = True # whether to flatten the output array if all elements are lists - - -class IterationStartNodeData(BaseNodeData): - """ - Iteration Start Node Data. - """ - - type: NodeType = BuiltinNodeTypes.ITERATION_START - - -class IterationState(BaseIterationState): - """ - Iteration State. - """ - - outputs: list[Any] = Field(default_factory=list) - current_output: Any = None - - class MetaData(BaseIterationState.MetaData): - """ - Data. - """ - - iterator_length: int - - def get_last_output(self) -> Any: - """ - Get last output. - """ - if self.outputs: - return self.outputs[-1] - return None - - def get_current_output(self) -> Any: - """ - Get current output. 
- """ - return self.current_output diff --git a/api/dify_graph/nodes/iteration/exc.py b/api/dify_graph/nodes/iteration/exc.py deleted file mode 100644 index d9947e09bc..0000000000 --- a/api/dify_graph/nodes/iteration/exc.py +++ /dev/null @@ -1,22 +0,0 @@ -class IterationNodeError(ValueError): - """Base class for iteration node errors.""" - - -class IteratorVariableNotFoundError(IterationNodeError): - """Raised when the iterator variable is not found.""" - - -class InvalidIteratorValueError(IterationNodeError): - """Raised when the iterator value is invalid.""" - - -class StartNodeIdNotFoundError(IterationNodeError): - """Raised when the start node ID is not found.""" - - -class IterationGraphNotFoundError(IterationNodeError): - """Raised when the iteration graph is not found.""" - - -class IterationIndexNotFoundError(IterationNodeError): - """Raised when the iteration index is not found.""" diff --git a/api/dify_graph/nodes/iteration/iteration_node.py b/api/dify_graph/nodes/iteration/iteration_node.py deleted file mode 100644 index 033ec8672f..0000000000 --- a/api/dify_graph/nodes/iteration/iteration_node.py +++ /dev/null @@ -1,626 +0,0 @@ -import logging -from collections.abc import Generator, Mapping, Sequence -from concurrent.futures import Future, ThreadPoolExecutor, as_completed -from datetime import UTC, datetime -from typing import TYPE_CHECKING, Any, NewType, cast - -from typing_extensions import TypeIs - -from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID -from dify_graph.entities.graph_config import NodeConfigDictAdapter -from dify_graph.enums import ( - BuiltinNodeTypes, - NodeExecutionType, - WorkflowNodeExecutionMetadataKey, - WorkflowNodeExecutionStatus, -) -from dify_graph.graph_events import ( - GraphNodeEventBase, - GraphRunFailedEvent, - GraphRunPartialSucceededEvent, - GraphRunSucceededEvent, -) -from dify_graph.model_runtime.entities.llm_entities import LLMUsage -from dify_graph.node_events import ( - IterationFailedEvent, - 
IterationNextEvent, - IterationStartedEvent, - IterationSucceededEvent, - NodeEventBase, - NodeRunResult, - StreamCompletedEvent, -) -from dify_graph.nodes.base import LLMUsageTrackingMixin -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.iteration.entities import ErrorHandleMode, IterationNodeData -from dify_graph.runtime import VariablePool -from dify_graph.variables import IntegerVariable, NoneSegment -from dify_graph.variables.segments import ArrayAnySegment, ArraySegment -from dify_graph.variables.variables import Variable -from libs.datetime_utils import naive_utc_now - -from .exc import ( - InvalidIteratorValueError, - IterationGraphNotFoundError, - IterationIndexNotFoundError, - IterationNodeError, - IteratorVariableNotFoundError, - StartNodeIdNotFoundError, -) - -if TYPE_CHECKING: - from dify_graph.context import IExecutionContext - from dify_graph.graph_engine import GraphEngine - -logger = logging.getLogger(__name__) - -EmptyArraySegment = NewType("EmptyArraySegment", ArraySegment) - - -class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]): - """ - Iteration Node. 
- """ - - node_type = BuiltinNodeTypes.ITERATION - execution_type = NodeExecutionType.CONTAINER - - @classmethod - def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]: - return { - "type": "iteration", - "config": { - "is_parallel": False, - "parallel_nums": 10, - "error_handle_mode": ErrorHandleMode.TERMINATED, - "flatten_output": True, - }, - } - - @classmethod - def version(cls) -> str: - return "1" - - def _run(self) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]: # type: ignore - variable = self._get_iterator_variable() - - if self._is_empty_iteration(variable): - yield from self._handle_empty_iteration(variable) - return - - iterator_list_value = self._validate_and_get_iterator_list(variable) - inputs = {"iterator_selector": iterator_list_value} - - self._validate_start_node() - - started_at = naive_utc_now() - iter_run_map: dict[str, float] = {} - outputs: list[object] = [] - usage_accumulator = [LLMUsage.empty_usage()] - - yield IterationStartedEvent( - start_at=started_at, - inputs=inputs, - metadata={"iteration_length": len(iterator_list_value)}, - ) - - try: - yield from self._execute_iterations( - iterator_list_value=iterator_list_value, - outputs=outputs, - iter_run_map=iter_run_map, - usage_accumulator=usage_accumulator, - ) - - self._accumulate_usage(usage_accumulator[0]) - yield from self._handle_iteration_success( - started_at=started_at, - inputs=inputs, - outputs=outputs, - iterator_list_value=iterator_list_value, - iter_run_map=iter_run_map, - usage=usage_accumulator[0], - ) - except IterationNodeError as e: - self._accumulate_usage(usage_accumulator[0]) - yield from self._handle_iteration_failure( - started_at=started_at, - inputs=inputs, - outputs=outputs, - iterator_list_value=iterator_list_value, - iter_run_map=iter_run_map, - usage=usage_accumulator[0], - error=e, - ) - - def _get_iterator_variable(self) -> ArraySegment | NoneSegment: - variable = 
self.graph_runtime_state.variable_pool.get(self.node_data.iterator_selector) - - if not variable: - raise IteratorVariableNotFoundError(f"iterator variable {self.node_data.iterator_selector} not found") - - if not isinstance(variable, ArraySegment) and not isinstance(variable, NoneSegment): - raise InvalidIteratorValueError(f"invalid iterator value: {variable}, please provide a list.") - - return variable - - def _is_empty_iteration(self, variable: ArraySegment | NoneSegment) -> TypeIs[NoneSegment | EmptyArraySegment]: - return isinstance(variable, NoneSegment) or len(variable.value) == 0 - - def _handle_empty_iteration(self, variable: ArraySegment | NoneSegment) -> Generator[NodeEventBase, None, None]: - # Try our best to preserve the type information. - if isinstance(variable, ArraySegment): - output = variable.model_copy(update={"value": []}) - else: - output = ArrayAnySegment(value=[]) - - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - # TODO(QuantumGhost): is it possible to compute the type of `output` - # from graph definition? 
- outputs={"output": output}, - ) - ) - - def _validate_and_get_iterator_list(self, variable: ArraySegment) -> Sequence[object]: - iterator_list_value = variable.to_object() - - if not isinstance(iterator_list_value, list): - raise InvalidIteratorValueError(f"Invalid iterator value: {iterator_list_value}, please provide a list.") - - return cast(list[object], iterator_list_value) - - def _validate_start_node(self) -> None: - if not self.node_data.start_node_id: - raise StartNodeIdNotFoundError(f"field start_node_id in iteration {self._node_id} not found") - - def _execute_iterations( - self, - iterator_list_value: Sequence[object], - outputs: list[object], - iter_run_map: dict[str, float], - usage_accumulator: list[LLMUsage], - ) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]: - if self.node_data.is_parallel: - # Parallel mode execution - yield from self._execute_parallel_iterations( - iterator_list_value=iterator_list_value, - outputs=outputs, - iter_run_map=iter_run_map, - usage_accumulator=usage_accumulator, - ) - else: - # Sequential mode execution - for index, item in enumerate(iterator_list_value): - iter_start_at = datetime.now(UTC).replace(tzinfo=None) - yield IterationNextEvent(index=index) - - graph_engine = self._create_graph_engine(index, item) - - # Run the iteration - yield from self._run_single_iter( - variable_pool=graph_engine.graph_runtime_state.variable_pool, - outputs=outputs, - graph_engine=graph_engine, - ) - - # Sync conversation variables after each iteration completes - self._sync_conversation_variables_from_snapshot( - self._extract_conversation_variable_snapshot( - variable_pool=graph_engine.graph_runtime_state.variable_pool - ) - ) - - # Accumulate usage from this iteration - usage_accumulator[0] = self._merge_usage( - usage_accumulator[0], graph_engine.graph_runtime_state.llm_usage - ) - iter_run_map[str(index)] = (datetime.now(UTC).replace(tzinfo=None) - iter_start_at).total_seconds() - - def _execute_parallel_iterations( 
- self, - iterator_list_value: Sequence[object], - outputs: list[object], - iter_run_map: dict[str, float], - usage_accumulator: list[LLMUsage], - ) -> Generator[GraphNodeEventBase | NodeEventBase, None, None]: - # Initialize outputs list with None values to maintain order - outputs.extend([None] * len(iterator_list_value)) - - # Determine the number of parallel workers - max_workers = min(self.node_data.parallel_nums, len(iterator_list_value)) - - with ThreadPoolExecutor(max_workers=max_workers) as executor: - # Submit all iteration tasks - future_to_index: dict[ - Future[ - tuple[ - float, - list[GraphNodeEventBase], - object | None, - dict[str, Variable], - LLMUsage, - ] - ], - int, - ] = {} - for index, item in enumerate(iterator_list_value): - yield IterationNextEvent(index=index) - future = executor.submit( - self._execute_single_iteration_parallel, - index=index, - item=item, - execution_context=self._capture_execution_context(), - ) - future_to_index[future] = index - - # Process completed iterations as they finish - for future in as_completed(future_to_index): - index = future_to_index[future] - try: - result = future.result() - ( - iteration_duration, - events, - output_value, - conversation_snapshot, - iteration_usage, - ) = result - - # Update outputs at the correct index - outputs[index] = output_value - - # Yield all events from this iteration - yield from events - - # The worker computes duration before we replay buffered events here, - # so slow downstream consumers don't inflate per-iteration timing. 
- iter_run_map[str(index)] = iteration_duration - - usage_accumulator[0] = self._merge_usage(usage_accumulator[0], iteration_usage) - - # Sync conversation variables after iteration completion - self._sync_conversation_variables_from_snapshot(conversation_snapshot) - - except Exception as e: - # Handle errors based on error_handle_mode - match self.node_data.error_handle_mode: - case ErrorHandleMode.TERMINATED: - # Cancel remaining futures and re-raise - for f in future_to_index: - if f != future: - f.cancel() - raise IterationNodeError(str(e)) - case ErrorHandleMode.CONTINUE_ON_ERROR: - outputs[index] = None - case ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT: - outputs[index] = None # Will be filtered later - - # Remove None values if in REMOVE_ABNORMAL_OUTPUT mode - if self.node_data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT: - outputs[:] = [output for output in outputs if output is not None] - - def _execute_single_iteration_parallel( - self, - index: int, - item: object, - execution_context: "IExecutionContext", - ) -> tuple[float, list[GraphNodeEventBase], object | None, dict[str, Variable], LLMUsage]: - """Execute a single iteration in parallel mode and return results.""" - with execution_context: - iter_start_at = datetime.now(UTC).replace(tzinfo=None) - events: list[GraphNodeEventBase] = [] - outputs_temp: list[object] = [] - - graph_engine = self._create_graph_engine(index, item) - - # Collect events instead of yielding them directly - for event in self._run_single_iter( - variable_pool=graph_engine.graph_runtime_state.variable_pool, - outputs=outputs_temp, - graph_engine=graph_engine, - ): - events.append(event) - - # Get the output value from the temporary outputs list - output_value = outputs_temp[0] if outputs_temp else None - conversation_snapshot = self._extract_conversation_variable_snapshot( - variable_pool=graph_engine.graph_runtime_state.variable_pool - ) - iteration_duration = (datetime.now(UTC).replace(tzinfo=None) - 
iter_start_at).total_seconds() - - return ( - iteration_duration, - events, - output_value, - conversation_snapshot, - graph_engine.graph_runtime_state.llm_usage, - ) - - def _capture_execution_context(self) -> "IExecutionContext": - """Capture current execution context for parallel iterations.""" - from dify_graph.context import capture_current_context - - return capture_current_context() - - def _handle_iteration_success( - self, - started_at: datetime, - inputs: dict[str, Sequence[object]], - outputs: list[object], - iterator_list_value: Sequence[object], - iter_run_map: dict[str, float], - *, - usage: LLMUsage, - ) -> Generator[NodeEventBase, None, None]: - # Flatten the list of lists if all outputs are lists - flattened_outputs = self._flatten_outputs_if_needed(outputs) - - yield IterationSucceededEvent( - start_at=started_at, - inputs=inputs, - outputs={"output": flattened_outputs}, - steps=len(iterator_list_value), - metadata={ - WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens, - WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price, - WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency, - WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map, - }, - ) - - # Yield final success event - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs={"output": flattened_outputs}, - metadata={ - WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens, - WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price, - WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency, - }, - llm_usage=usage, - ) - ) - - def _flatten_outputs_if_needed(self, outputs: list[object]) -> list[object]: - """ - Flatten the outputs list if all elements are lists. - This maintains backward compatibility with version 1.8.1 behavior. - - If flatten_output is False, returns outputs as-is (nested structure). 
- If flatten_output is True (default), flattens the list if all elements are lists. - """ - # If flatten_output is disabled, return outputs as-is - if not self.node_data.flatten_output: - return outputs - - if not outputs: - return outputs - - # Check if all non-None outputs are lists - non_none_outputs: list[object] = [output for output in outputs if output is not None] - if not non_none_outputs: - return outputs - - if all(isinstance(output, list) for output in non_none_outputs): - # Flatten the list of lists - flattened: list[Any] = [] - for output in outputs: - if isinstance(output, list): - flattened.extend(output) - elif output is not None: - # This shouldn't happen based on our check, but handle it gracefully - flattened.append(output) - return flattened - - return outputs - - def _handle_iteration_failure( - self, - started_at: datetime, - inputs: dict[str, Sequence[object]], - outputs: list[object], - iterator_list_value: Sequence[object], - iter_run_map: dict[str, float], - *, - usage: LLMUsage, - error: IterationNodeError, - ) -> Generator[NodeEventBase, None, None]: - # Flatten the list of lists if all outputs are lists (even in failure case) - flattened_outputs = self._flatten_outputs_if_needed(outputs) - - yield IterationFailedEvent( - start_at=started_at, - inputs=inputs, - outputs={"output": flattened_outputs}, - steps=len(iterator_list_value), - metadata={ - WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens, - WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price, - WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency, - WorkflowNodeExecutionMetadataKey.ITERATION_DURATION_MAP: iter_run_map, - }, - error=str(error), - ) - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error=str(error), - metadata={ - WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens, - WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price, - 
WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency, - }, - llm_usage=usage, - ) - ) - - @classmethod - def _extract_variable_selector_to_variable_mapping( - cls, - *, - graph_config: Mapping[str, Any], - node_id: str, - node_data: IterationNodeData, - ) -> Mapping[str, Sequence[str]]: - variable_mapping: dict[str, Sequence[str]] = { - f"{node_id}.input_selector": node_data.iterator_selector, - } - iteration_node_ids = set() - - # Find all nodes that belong to this loop - nodes = graph_config.get("nodes", []) - for node in nodes: - node_config_data = node.get("data", {}) - if node_config_data.get("iteration_id") == node_id: - in_iteration_node_id = node.get("id") - if in_iteration_node_id: - iteration_node_ids.add(in_iteration_node_id) - - # Get node configs from graph_config instead of non-existent node_id_config_mapping - node_configs = {node["id"]: node for node in graph_config.get("nodes", []) if "id" in node} - for sub_node_id, sub_node_config in node_configs.items(): - if sub_node_config.get("data", {}).get("iteration_id") != node_id: - continue - - # variable selector to variable mapping - try: - typed_sub_node_config = NodeConfigDictAdapter.validate_python(sub_node_config) - node_type = typed_sub_node_config["data"].type - node_mapping = Node.get_node_type_classes_mapping() - if node_type not in node_mapping: - continue - node_version = str(typed_sub_node_config["data"].version) - node_cls = node_mapping[node_type][node_version] - - sub_node_variable_mapping = node_cls.extract_variable_selector_to_variable_mapping( - graph_config=graph_config, config=typed_sub_node_config - ) - sub_node_variable_mapping = cast(dict[str, Sequence[str]], sub_node_variable_mapping) - except NotImplementedError: - sub_node_variable_mapping = {} - - # remove iteration variables - sub_node_variable_mapping = { - sub_node_id + "." 
+ key: value - for key, value in sub_node_variable_mapping.items() - if value[0] != node_id - } - - variable_mapping.update(sub_node_variable_mapping) - - # remove variable out from iteration - variable_mapping = {key: value for key, value in variable_mapping.items() if value[0] not in iteration_node_ids} - - return variable_mapping - - def _extract_conversation_variable_snapshot(self, *, variable_pool: VariablePool) -> dict[str, Variable]: - conversation_variables = variable_pool.variable_dictionary.get(CONVERSATION_VARIABLE_NODE_ID, {}) - return {name: variable.model_copy(deep=True) for name, variable in conversation_variables.items()} - - def _sync_conversation_variables_from_snapshot(self, snapshot: dict[str, Variable]) -> None: - parent_pool = self.graph_runtime_state.variable_pool - parent_conversations = parent_pool.variable_dictionary.get(CONVERSATION_VARIABLE_NODE_ID, {}) - - current_keys = set(parent_conversations.keys()) - snapshot_keys = set(snapshot.keys()) - - for removed_key in current_keys - snapshot_keys: - parent_pool.remove((CONVERSATION_VARIABLE_NODE_ID, removed_key)) - - for name, variable in snapshot.items(): - parent_pool.add((CONVERSATION_VARIABLE_NODE_ID, name), variable) - - def _append_iteration_info_to_event( - self, - event: GraphNodeEventBase, - iter_run_index: int, - ): - event.in_iteration_id = self._node_id - iter_metadata = { - WorkflowNodeExecutionMetadataKey.ITERATION_ID: self._node_id, - WorkflowNodeExecutionMetadataKey.ITERATION_INDEX: iter_run_index, - } - - current_metadata = event.node_run_result.metadata - if WorkflowNodeExecutionMetadataKey.ITERATION_ID not in current_metadata: - event.node_run_result.metadata = {**current_metadata, **iter_metadata} - - def _run_single_iter( - self, - *, - variable_pool: VariablePool, - outputs: list[object], - graph_engine: "GraphEngine", - ) -> Generator[GraphNodeEventBase, None, None]: - rst = graph_engine.run() - # get current iteration index - index_variable = 
variable_pool.get([self._node_id, "index"]) - if not isinstance(index_variable, IntegerVariable): - raise IterationIndexNotFoundError(f"iteration {self._node_id} current index not found") - current_index = index_variable.value - for event in rst: - if isinstance(event, GraphNodeEventBase) and event.node_type == BuiltinNodeTypes.ITERATION_START: - continue - - if isinstance(event, GraphNodeEventBase): - self._append_iteration_info_to_event(event=event, iter_run_index=current_index) - yield event - elif isinstance(event, (GraphRunSucceededEvent, GraphRunPartialSucceededEvent)): - result = variable_pool.get(self.node_data.output_selector) - if result is None: - outputs.append(None) - else: - outputs.append(result.to_object()) - return - elif isinstance(event, GraphRunFailedEvent): - match self.node_data.error_handle_mode: - case ErrorHandleMode.TERMINATED: - raise IterationNodeError(event.error) - case ErrorHandleMode.CONTINUE_ON_ERROR: - outputs.append(None) - return - case ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT: - return - - def _create_graph_engine(self, index: int, item: object): - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import ChildGraphNotFoundError, GraphRuntimeState - - # Create GraphInitParams for child graph execution. 
- graph_init_params = GraphInitParams( - workflow_id=self.workflow_id, - graph_config=self.graph_config, - run_context=self.run_context, - call_depth=self.workflow_call_depth, - ) - # Create a deep copy of the variable pool for each iteration - variable_pool_copy = self.graph_runtime_state.variable_pool.model_copy(deep=True) - - # append iteration variable (item, index) to variable pool - variable_pool_copy.add([self._node_id, "index"], index) - variable_pool_copy.add([self._node_id, "item"], item) - - # Create a new GraphRuntimeState for this iteration - graph_runtime_state_copy = GraphRuntimeState( - variable_pool=variable_pool_copy, - start_at=self.graph_runtime_state.start_at, - total_tokens=0, - node_run_steps=0, - ) - root_node_id = self.node_data.start_node_id - if root_node_id is None: - raise StartNodeIdNotFoundError(f"field start_node_id in iteration {self._node_id} not found") - - try: - return self.graph_runtime_state.create_child_engine( - workflow_id=self.workflow_id, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state_copy, - graph_config=self.graph_config, - root_node_id=root_node_id, - ) - except ChildGraphNotFoundError as exc: - raise IterationGraphNotFoundError("iteration graph not found") from exc diff --git a/api/dify_graph/nodes/iteration/iteration_start_node.py b/api/dify_graph/nodes/iteration/iteration_start_node.py deleted file mode 100644 index a8ecf3d83b..0000000000 --- a/api/dify_graph/nodes/iteration/iteration_start_node.py +++ /dev/null @@ -1,22 +0,0 @@ -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.iteration.entities import IterationStartNodeData - - -class IterationStartNode(Node[IterationStartNodeData]): - """ - Iteration Start Node. 
- """ - - node_type = BuiltinNodeTypes.ITERATION_START - - @classmethod - def version(cls) -> str: - return "1" - - def _run(self) -> NodeRunResult: - """ - Run the node. - """ - return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED) diff --git a/api/dify_graph/nodes/list_operator/__init__.py b/api/dify_graph/nodes/list_operator/__init__.py deleted file mode 100644 index 1877586ef4..0000000000 --- a/api/dify_graph/nodes/list_operator/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .node import ListOperatorNode - -__all__ = ["ListOperatorNode"] diff --git a/api/dify_graph/nodes/list_operator/entities.py b/api/dify_graph/nodes/list_operator/entities.py deleted file mode 100644 index 41b3a40b78..0000000000 --- a/api/dify_graph/nodes/list_operator/entities.py +++ /dev/null @@ -1,71 +0,0 @@ -from collections.abc import Sequence -from enum import StrEnum - -from pydantic import BaseModel, Field - -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, NodeType - - -class FilterOperator(StrEnum): - # string conditions - CONTAINS = "contains" - START_WITH = "start with" - END_WITH = "end with" - IS = "is" - IN = "in" - EMPTY = "empty" - NOT_CONTAINS = "not contains" - IS_NOT = "is not" - NOT_IN = "not in" - NOT_EMPTY = "not empty" - # number conditions - EQUAL = "=" - NOT_EQUAL = "≠" - LESS_THAN = "<" - GREATER_THAN = ">" - GREATER_THAN_OR_EQUAL = "≥" - LESS_THAN_OR_EQUAL = "≤" - - -class Order(StrEnum): - ASC = "asc" - DESC = "desc" - - -class FilterCondition(BaseModel): - key: str = "" - comparison_operator: FilterOperator = FilterOperator.CONTAINS - # the value is bool if the filter operator is comparing with - # a boolean constant. 
- value: str | Sequence[str] | bool = "" - - -class FilterBy(BaseModel): - enabled: bool = False - conditions: Sequence[FilterCondition] = Field(default_factory=list) - - -class OrderByConfig(BaseModel): - enabled: bool = False - key: str = "" - value: Order = Order.ASC - - -class Limit(BaseModel): - enabled: bool = False - size: int = -1 - - -class ExtractConfig(BaseModel): - enabled: bool = False - serial: str = "1" - - -class ListOperatorNodeData(BaseNodeData): - type: NodeType = BuiltinNodeTypes.LIST_OPERATOR - variable: Sequence[str] = Field(default_factory=list) - filter_by: FilterBy - order_by: OrderByConfig - limit: Limit - extract_by: ExtractConfig = Field(default_factory=ExtractConfig) diff --git a/api/dify_graph/nodes/list_operator/exc.py b/api/dify_graph/nodes/list_operator/exc.py deleted file mode 100644 index f88aa0be29..0000000000 --- a/api/dify_graph/nodes/list_operator/exc.py +++ /dev/null @@ -1,16 +0,0 @@ -class ListOperatorError(ValueError): - """Base class for all ListOperator errors.""" - - pass - - -class InvalidFilterValueError(ListOperatorError): - pass - - -class InvalidKeyError(ListOperatorError): - pass - - -class InvalidConditionError(ListOperatorError): - pass diff --git a/api/dify_graph/nodes/list_operator/node.py b/api/dify_graph/nodes/list_operator/node.py deleted file mode 100644 index dc8b8904f7..0000000000 --- a/api/dify_graph/nodes/list_operator/node.py +++ /dev/null @@ -1,345 +0,0 @@ -from collections.abc import Callable, Sequence -from typing import Any, TypeAlias, TypeVar - -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from dify_graph.file import File -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.base.node import Node -from dify_graph.variables import ArrayFileSegment, ArrayNumberSegment, ArrayStringSegment -from dify_graph.variables.segments import ArrayAnySegment, ArrayBooleanSegment, ArraySegment - -from .entities import FilterOperator, ListOperatorNodeData, Order 
-from .exc import InvalidConditionError, InvalidFilterValueError, InvalidKeyError, ListOperatorError - -_SUPPORTED_TYPES_TUPLE = ( - ArrayFileSegment, - ArrayNumberSegment, - ArrayStringSegment, - ArrayBooleanSegment, -) -_SUPPORTED_TYPES_ALIAS: TypeAlias = ArrayFileSegment | ArrayNumberSegment | ArrayStringSegment | ArrayBooleanSegment - - -_T = TypeVar("_T") - - -def _negation(filter_: Callable[[_T], bool]) -> Callable[[_T], bool]: - """Returns the negation of a given filter function. If the original filter - returns `True` for a value, the negated filter will return `False`, and vice versa. - """ - - def wrapper(value: _T) -> bool: - return not filter_(value) - - return wrapper - - -class ListOperatorNode(Node[ListOperatorNodeData]): - node_type = BuiltinNodeTypes.LIST_OPERATOR - - @classmethod - def version(cls) -> str: - return "1" - - def _run(self): - inputs: dict[str, Sequence[object]] = {} - process_data: dict[str, Sequence[object]] = {} - outputs: dict[str, Any] = {} - - variable = self.graph_runtime_state.variable_pool.get(self.node_data.variable) - if variable is None: - error_message = f"Variable not found for selector: {self.node_data.variable}" - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, error=error_message, inputs=inputs, outputs=outputs - ) - if not variable.value: - inputs = {"variable": []} - process_data = {"variable": []} - if isinstance(variable, ArraySegment): - result = variable.model_copy(update={"value": []}) - else: - result = ArrayAnySegment(value=[]) - outputs = {"result": result, "first_record": None, "last_record": None} - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=inputs, - process_data=process_data, - outputs=outputs, - ) - if not isinstance(variable, _SUPPORTED_TYPES_TUPLE): - error_message = f"Variable {self.node_data.variable} is not an array type, actual type: {type(variable)}" - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, error=error_message, 
inputs=inputs, outputs=outputs - ) - - if isinstance(variable, ArrayFileSegment): - inputs = {"variable": [item.to_dict() for item in variable.value]} - process_data["variable"] = [item.to_dict() for item in variable.value] - else: - inputs = {"variable": variable.value} - process_data["variable"] = variable.value - - try: - # Filter - if self.node_data.filter_by.enabled: - variable = self._apply_filter(variable) - - # Extract - if self.node_data.extract_by.enabled: - variable = self._extract_slice(variable) - - # Order - if self.node_data.order_by.enabled: - variable = self._apply_order(variable) - - # Slice - if self.node_data.limit.enabled: - variable = self._apply_slice(variable) - - outputs = { - "result": variable, - "first_record": variable.value[0] if variable.value else None, - "last_record": variable.value[-1] if variable.value else None, - } - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=inputs, - process_data=process_data, - outputs=outputs, - ) - except ListOperatorError as e: - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error=str(e), - inputs=inputs, - process_data=process_data, - outputs=outputs, - ) - - def _apply_filter(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS: - filter_func: Callable[[Any], bool] - result: list[Any] = [] - for condition in self.node_data.filter_by.conditions: - if isinstance(variable, ArrayStringSegment): - if not isinstance(condition.value, str): - raise InvalidFilterValueError(f"Invalid filter value: {condition.value}") - value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text - filter_func = _get_string_filter_func(condition=condition.comparison_operator, value=value) - result = list(filter(filter_func, variable.value)) - variable = variable.model_copy(update={"value": result}) - elif isinstance(variable, ArrayNumberSegment): - if not isinstance(condition.value, str): - raise InvalidFilterValueError(f"Invalid 
filter value: {condition.value}") - value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text - filter_func = _get_number_filter_func(condition=condition.comparison_operator, value=float(value)) - result = list(filter(filter_func, variable.value)) - variable = variable.model_copy(update={"value": result}) - elif isinstance(variable, ArrayFileSegment): - if isinstance(condition.value, str): - value = self.graph_runtime_state.variable_pool.convert_template(condition.value).text - elif isinstance(condition.value, bool): - raise ValueError(f"File filter expects a string value, got {type(condition.value)}") - else: - value = condition.value - filter_func = _get_file_filter_func( - key=condition.key, - condition=condition.comparison_operator, - value=value, - ) - result = list(filter(filter_func, variable.value)) - variable = variable.model_copy(update={"value": result}) - else: - if not isinstance(condition.value, bool): - raise ValueError(f"Boolean filter expects a boolean value, got {type(condition.value)}") - filter_func = _get_boolean_filter_func(condition=condition.comparison_operator, value=condition.value) - result = list(filter(filter_func, variable.value)) - variable = variable.model_copy(update={"value": result}) - return variable - - def _apply_order(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS: - if isinstance(variable, (ArrayStringSegment, ArrayNumberSegment, ArrayBooleanSegment)): - result = sorted(variable.value, reverse=self.node_data.order_by.value == Order.DESC) - variable = variable.model_copy(update={"value": result}) - else: - result = _order_file( - order=self.node_data.order_by.value, order_by=self.node_data.order_by.key, array=variable.value - ) - variable = variable.model_copy(update={"value": result}) - - return variable - - def _apply_slice(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS: - result = variable.value[: self.node_data.limit.size] - return 
variable.model_copy(update={"value": result}) - - def _extract_slice(self, variable: _SUPPORTED_TYPES_ALIAS) -> _SUPPORTED_TYPES_ALIAS: - value = int(self.graph_runtime_state.variable_pool.convert_template(self.node_data.extract_by.serial).text) - if value < 1: - raise ValueError(f"Invalid serial index: must be >= 1, got {value}") - if value > len(variable.value): - raise InvalidKeyError(f"Invalid serial index: must be <= {len(variable.value)}, got {value}") - value -= 1 - result = variable.value[value] - return variable.model_copy(update={"value": [result]}) - - -def _get_file_extract_number_func(*, key: str) -> Callable[[File], int]: - match key: - case "size": - return lambda x: x.size - case _: - raise InvalidKeyError(f"Invalid key: {key}") - - -def _get_file_extract_string_func(*, key: str) -> Callable[[File], str]: - match key: - case "name": - return lambda x: x.filename or "" - case "type": - return lambda x: str(x.type) - case "extension": - return lambda x: x.extension or "" - case "mime_type": - return lambda x: x.mime_type or "" - case "transfer_method": - return lambda x: str(x.transfer_method) - case "url": - return lambda x: x.remote_url or "" - case "related_id": - return lambda x: x.related_id or "" - case _: - raise InvalidKeyError(f"Invalid key: {key}") - - -def _get_string_filter_func(*, condition: str, value: str) -> Callable[[str], bool]: - match condition: - case "contains": - return _contains(value) - case "start with": - return _startswith(value) - case "end with": - return _endswith(value) - case "is": - return _is(value) - case "in": - return _in(value) - case "empty": - return lambda x: x == "" - case "not contains": - return _negation(_contains(value)) - case "is not": - return _negation(_is(value)) - case "not in": - return _negation(_in(value)) - case "not empty": - return lambda x: x != "" - case _: - raise InvalidConditionError(f"Invalid condition: {condition}") - - -def _get_sequence_filter_func(*, condition: str, value: 
Sequence[str]) -> Callable[[str], bool]: - match condition: - case "in": - return _in(value) - case "not in": - return _negation(_in(value)) - case _: - raise InvalidConditionError(f"Invalid condition: {condition}") - - -def _get_number_filter_func(*, condition: str, value: int | float) -> Callable[[int | float], bool]: - match condition: - case "=": - return _eq(value) - case "≠": - return _ne(value) - case "<": - return _lt(value) - case "≤": - return _le(value) - case ">": - return _gt(value) - case "≥": - return _ge(value) - case _: - raise InvalidConditionError(f"Invalid condition: {condition}") - - -def _get_boolean_filter_func(*, condition: FilterOperator, value: bool) -> Callable[[bool], bool]: - match condition: - case FilterOperator.IS: - return _is(value) - case FilterOperator.IS_NOT: - return _negation(_is(value)) - case _: - raise InvalidConditionError(f"Invalid condition: {condition}") - - -def _get_file_filter_func(*, key: str, condition: str, value: str | Sequence[str]) -> Callable[[File], bool]: - if key in {"name", "extension", "mime_type", "url", "related_id"} and isinstance(value, str): - extract_func = _get_file_extract_string_func(key=key) - return lambda x: _get_string_filter_func(condition=condition, value=value)(extract_func(x)) - if key in {"type", "transfer_method"}: - extract_func = _get_file_extract_string_func(key=key) - return lambda x: _get_sequence_filter_func(condition=condition, value=value)(extract_func(x)) - elif key == "size" and isinstance(value, str): - extract_number = _get_file_extract_number_func(key=key) - return lambda x: _get_number_filter_func(condition=condition, value=float(value))(extract_number(x)) - else: - raise InvalidKeyError(f"Invalid key: {key}") - - -def _contains(value: str) -> Callable[[str], bool]: - return lambda x: value in x - - -def _startswith(value: str) -> Callable[[str], bool]: - return lambda x: x.startswith(value) - - -def _endswith(value: str) -> Callable[[str], bool]: - return lambda x: 
x.endswith(value) - - -def _is(value: _T) -> Callable[[_T], bool]: - return lambda x: x == value - - -def _in(value: str | Sequence[str]) -> Callable[[str], bool]: - return lambda x: x in value - - -def _eq(value: int | float) -> Callable[[int | float], bool]: - return lambda x: x == value - - -def _ne(value: int | float) -> Callable[[int | float], bool]: - return lambda x: x != value - - -def _lt(value: int | float) -> Callable[[int | float], bool]: - return lambda x: x < value - - -def _le(value: int | float) -> Callable[[int | float], bool]: - return lambda x: x <= value - - -def _gt(value: int | float) -> Callable[[int | float], bool]: - return lambda x: x > value - - -def _ge(value: int | float) -> Callable[[int | float], bool]: - return lambda x: x >= value - - -def _order_file(*, order: Order, order_by: str = "", array: Sequence[File]): - extract_func: Callable[[File], Any] - if order_by in {"name", "type", "extension", "mime_type", "transfer_method", "url", "related_id"}: - extract_func = _get_file_extract_string_func(key=order_by) - return sorted(array, key=lambda x: extract_func(x), reverse=order == Order.DESC) - elif order_by == "size": - extract_func = _get_file_extract_number_func(key=order_by) - return sorted(array, key=lambda x: extract_func(x), reverse=order == Order.DESC) - else: - raise InvalidKeyError(f"Invalid order key: {order_by}") diff --git a/api/dify_graph/nodes/llm/__init__.py b/api/dify_graph/nodes/llm/__init__.py deleted file mode 100644 index f7bc713f63..0000000000 --- a/api/dify_graph/nodes/llm/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -from .entities import ( - LLMNodeChatModelMessage, - LLMNodeCompletionModelPromptTemplate, - LLMNodeData, - ModelConfig, - VisionConfig, -) -from .node import LLMNode - -__all__ = [ - "LLMNode", - "LLMNodeChatModelMessage", - "LLMNodeCompletionModelPromptTemplate", - "LLMNodeData", - "ModelConfig", - "VisionConfig", -] diff --git a/api/dify_graph/nodes/llm/entities.py 
b/api/dify_graph/nodes/llm/entities.py deleted file mode 100644 index 6ca01a21da..0000000000 --- a/api/dify_graph/nodes/llm/entities.py +++ /dev/null @@ -1,100 +0,0 @@ -from collections.abc import Mapping, Sequence -from typing import Any, Literal - -from pydantic import BaseModel, Field, field_validator - -from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, NodeType -from dify_graph.model_runtime.entities import ImagePromptMessageContent, LLMMode -from dify_graph.nodes.base.entities import VariableSelector - - -class ModelConfig(BaseModel): - provider: str - name: str - mode: LLMMode - completion_params: dict[str, Any] = Field(default_factory=dict) - - -class ContextConfig(BaseModel): - enabled: bool - variable_selector: list[str] | None = None - - -class VisionConfigOptions(BaseModel): - variable_selector: Sequence[str] = Field(default_factory=lambda: ["sys", "files"]) - detail: ImagePromptMessageContent.DETAIL = ImagePromptMessageContent.DETAIL.HIGH - - -class VisionConfig(BaseModel): - enabled: bool = False - configs: VisionConfigOptions = Field(default_factory=VisionConfigOptions) - - @field_validator("configs", mode="before") - @classmethod - def convert_none_configs(cls, v: Any): - if v is None: - return VisionConfigOptions() - return v - - -class PromptConfig(BaseModel): - jinja2_variables: Sequence[VariableSelector] = Field(default_factory=list) - - @field_validator("jinja2_variables", mode="before") - @classmethod - def convert_none_jinja2_variables(cls, v: Any): - if v is None: - return [] - return v - - -class LLMNodeChatModelMessage(ChatModelMessage): - text: str = "" - jinja2_text: str | None = None - - -class LLMNodeCompletionModelPromptTemplate(CompletionModelPromptTemplate): - jinja2_text: str | None = None - - -class LLMNodeData(BaseNodeData): - type: NodeType = 
BuiltinNodeTypes.LLM - model: ModelConfig - prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate - prompt_config: PromptConfig = Field(default_factory=PromptConfig) - memory: MemoryConfig | None = None - context: ContextConfig - vision: VisionConfig = Field(default_factory=VisionConfig) - structured_output: Mapping[str, Any] | None = None - # We used 'structured_output_enabled' in the past, but it's not a good name. - structured_output_switch_on: bool = Field(False, alias="structured_output_enabled") - reasoning_format: Literal["separated", "tagged"] = Field( - # Keep tagged as default for backward compatibility - default="tagged", - description=( - """ - Strategy for handling model reasoning output. - - separated: Return clean text (without tags) + reasoning_content field. - Recommended for new workflows. Enables safe downstream parsing and - workflow variable access: {{#node_id.reasoning_content#}} - - tagged : Return original text (with tags) + reasoning_content field. - Maintains full backward compatibility while still providing reasoning_content - for workflow automation. Frontend thinking panels work as before. 
- """ - ), - ) - - @field_validator("prompt_config", mode="before") - @classmethod - def convert_none_prompt_config(cls, v: Any): - if v is None: - return PromptConfig() - return v - - @property - def structured_output_enabled(self) -> bool: - return self.structured_output_switch_on and self.structured_output is not None diff --git a/api/dify_graph/nodes/llm/exc.py b/api/dify_graph/nodes/llm/exc.py deleted file mode 100644 index 4d16095296..0000000000 --- a/api/dify_graph/nodes/llm/exc.py +++ /dev/null @@ -1,45 +0,0 @@ -class LLMNodeError(ValueError): - """Base class for LLM Node errors.""" - - -class VariableNotFoundError(LLMNodeError): - """Raised when a required variable is not found.""" - - -class InvalidContextStructureError(LLMNodeError): - """Raised when the context structure is invalid.""" - - -class InvalidVariableTypeError(LLMNodeError): - """Raised when the variable type is invalid.""" - - -class ModelNotExistError(LLMNodeError): - """Raised when the specified model does not exist.""" - - -class LLMModeRequiredError(LLMNodeError): - """Raised when LLM mode is required but not provided.""" - - -class NoPromptFoundError(LLMNodeError): - """Raised when no prompt is found in the LLM configuration.""" - - -class TemplateTypeNotSupportError(LLMNodeError): - def __init__(self, *, type_name: str): - super().__init__(f"Prompt type {type_name} is not supported.") - - -class MemoryRolePrefixRequiredError(LLMNodeError): - """Raised when memory role prefix is required for completion model.""" - - -class FileTypeNotSupportError(LLMNodeError): - def __init__(self, *, type_name: str): - super().__init__(f"{type_name} type is not supported by this model") - - -class UnsupportedPromptContentTypeError(LLMNodeError): - def __init__(self, *, type_name: str): - super().__init__(f"Prompt content type {type_name} is not supported.") diff --git a/api/dify_graph/nodes/llm/file_saver.py b/api/dify_graph/nodes/llm/file_saver.py deleted file mode 100644 index 50e52a3b6f..0000000000 
--- a/api/dify_graph/nodes/llm/file_saver.py +++ /dev/null @@ -1,144 +0,0 @@ -import mimetypes -import typing as tp - -from constants.mimetypes import DEFAULT_EXTENSION, DEFAULT_MIME_TYPE -from core.tools.signature import sign_tool_file -from core.tools.tool_file_manager import ToolFileManager -from dify_graph.file import File, FileTransferMethod, FileType -from dify_graph.nodes.protocols import HttpClientProtocol - - -class LLMFileSaver(tp.Protocol): - """LLMFileSaver is responsible for save multimodal output returned by - LLM. - """ - - def save_binary_string( - self, - data: bytes, - mime_type: str, - file_type: FileType, - extension_override: str | None = None, - ) -> File: - """save_binary_string saves the inline file data returned by LLM. - - Currently (2025-04-30), only some of Google Gemini models will return - multimodal output as inline data. - - :param data: the contents of the file - :param mime_type: the media type of the file, specified by rfc6838 - (https://datatracker.ietf.org/doc/html/rfc6838) - :param file_type: The file type of the inline file. - :param extension_override: Override the auto-detected file extension while saving this file. - - The default value is `None`, which means do not override the file extension and guessing it - from the `mime_type` attribute while saving the file. - - Setting it to values other than `None` means override the file's extension, and - will bypass the extension guessing saving the file. - - Specially, setting it to empty string (`""`) will leave the file extension empty. - - When it is not `None` or empty string (`""`), it should be a string beginning with a - dot (`.`). For example, `.py` and `.tar.gz` are both valid values, while `py` - and `tar.gz` are not. - """ - raise NotImplementedError() - - def save_remote_url(self, url: str, file_type: FileType) -> File: - """save_remote_url saves the file from a remote url returned by LLM. - - Currently (2025-04-30), no model returns multimodel output as a url. 
- - :param url: the url of the file. - :param file_type: the file type of the file, check `FileType` enum for reference. - """ - raise NotImplementedError() - - -class FileSaverImpl(LLMFileSaver): - _tenant_id: str - _user_id: str - - def __init__(self, user_id: str, tenant_id: str, http_client: HttpClientProtocol): - self._user_id = user_id - self._tenant_id = tenant_id - self._http_client = http_client - - def _get_tool_file_manager(self): - return ToolFileManager() - - def save_remote_url(self, url: str, file_type: FileType) -> File: - http_response = self._http_client.get(url) - http_response.raise_for_status() - data = http_response.content - mime_type_from_header = http_response.headers.get("Content-Type") - mime_type, extension = _extract_content_type_and_extension(url, mime_type_from_header) - return self.save_binary_string(data, mime_type, file_type, extension_override=extension) - - def save_binary_string( - self, - data: bytes, - mime_type: str, - file_type: FileType, - extension_override: str | None = None, - ) -> File: - tool_file_manager = self._get_tool_file_manager() - tool_file = tool_file_manager.create_file_by_raw( - user_id=self._user_id, - tenant_id=self._tenant_id, - # TODO(QuantumGhost): what is conversation id? - conversation_id=None, - file_binary=data, - mimetype=mime_type, - ) - extension_override = _validate_extension_override(extension_override) - extension = _get_extension(mime_type, extension_override) - url = sign_tool_file(tool_file.id, extension) - - return File( - tenant_id=self._tenant_id, - type=file_type, - transfer_method=FileTransferMethod.TOOL_FILE, - filename=tool_file.name, - extension=extension, - mime_type=mime_type, - size=len(data), - related_id=tool_file.id, - url=url, - storage_key=tool_file.file_key, - ) - - -def _get_extension(mime_type: str, extension_override: str | None = None) -> str: - """get_extension return the extension of file. 
- - If the `extension_override` parameter is set, this function should honor it and - return its value. - """ - if extension_override is not None: - return extension_override - return mimetypes.guess_extension(mime_type) or DEFAULT_EXTENSION - - -def _extract_content_type_and_extension(url: str, content_type_header: str | None) -> tuple[str, str]: - """_extract_content_type_and_extension tries to - guess content type of file from url and `Content-Type` header in response. - """ - if content_type_header: - extension = mimetypes.guess_extension(content_type_header) or DEFAULT_EXTENSION - return content_type_header, extension - content_type = mimetypes.guess_type(url)[0] or DEFAULT_MIME_TYPE - extension = mimetypes.guess_extension(content_type) or DEFAULT_EXTENSION - return content_type, extension - - -def _validate_extension_override(extension_override: str | None) -> str | None: - # `extension_override` is allow to be `None or `""`. - if extension_override is None: - return None - if extension_override == "": - return "" - if not extension_override.startswith("."): - raise ValueError("extension_override should start with '.' 
if not None or empty.", extension_override) - return extension_override diff --git a/api/dify_graph/nodes/llm/llm_utils.py b/api/dify_graph/nodes/llm/llm_utils.py deleted file mode 100644 index 8682c3682c..0000000000 --- a/api/dify_graph/nodes/llm/llm_utils.py +++ /dev/null @@ -1,543 +0,0 @@ -from __future__ import annotations - -import json -import logging -import re -from collections.abc import Mapping, Sequence -from typing import Any, cast - -from core.model_manager import ModelInstance -from dify_graph.file import FileType, file_manager -from dify_graph.file.models import File -from dify_graph.model_runtime.entities import ( - ImagePromptMessageContent, - PromptMessage, - PromptMessageContentType, - PromptMessageRole, - TextPromptMessageContent, -) -from dify_graph.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - PromptMessageContentUnionTypes, - SystemPromptMessage, - UserPromptMessage, -) -from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelFeature, ModelPropertyKey -from dify_graph.model_runtime.memory import PromptMessageMemory -from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from dify_graph.nodes.base.entities import VariableSelector -from dify_graph.runtime import VariablePool -from dify_graph.variables import ArrayFileSegment, FileSegment -from dify_graph.variables.segments import ArrayAnySegment, NoneSegment - -from .entities import LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate, MemoryConfig -from .exc import ( - InvalidVariableTypeError, - MemoryRolePrefixRequiredError, - NoPromptFoundError, - TemplateTypeNotSupportError, -) -from .protocols import TemplateRenderer - -logger = logging.getLogger(__name__) - -VARIABLE_PATTERN = re.compile(r"\{\{#[^#]+#\}\}") -MAX_RESOLVED_VALUE_LENGTH = 1024 - - -def fetch_model_schema(*, model_instance: ModelInstance) -> AIModelEntity: - model_schema = cast(LargeLanguageModel, 
model_instance.model_type_instance).get_model_schema( - model_instance.model_name, - dict(model_instance.credentials), - ) - if not model_schema: - raise ValueError(f"Model schema not found for {model_instance.model_name}") - return model_schema - - -def fetch_files(variable_pool: VariablePool, selector: Sequence[str]) -> Sequence[File]: - variable = variable_pool.get(selector) - if variable is None: - return [] - elif isinstance(variable, FileSegment): - return [variable.value] - elif isinstance(variable, ArrayFileSegment): - return variable.value - elif isinstance(variable, NoneSegment | ArrayAnySegment): - return [] - raise InvalidVariableTypeError(f"Invalid variable type: {type(variable)}") - - -def convert_history_messages_to_text( - *, - history_messages: Sequence[PromptMessage], - human_prefix: str, - ai_prefix: str, -) -> str: - string_messages: list[str] = [] - for message in history_messages: - if message.role == PromptMessageRole.USER: - role = human_prefix - elif message.role == PromptMessageRole.ASSISTANT: - role = ai_prefix - else: - continue - - if isinstance(message.content, list): - content_parts = [] - for content in message.content: - if isinstance(content, TextPromptMessageContent): - content_parts.append(content.data) - elif isinstance(content, ImagePromptMessageContent): - content_parts.append("[image]") - - inner_msg = "\n".join(content_parts) - string_messages.append(f"{role}: {inner_msg}") - else: - string_messages.append(f"{role}: {message.content}") - - return "\n".join(string_messages) - - -def fetch_memory_text( - *, - memory: PromptMessageMemory, - max_token_limit: int, - message_limit: int | None = None, - human_prefix: str = "Human", - ai_prefix: str = "Assistant", -) -> str: - history_messages = memory.get_history_prompt_messages( - max_token_limit=max_token_limit, - message_limit=message_limit, - ) - return convert_history_messages_to_text( - history_messages=history_messages, - human_prefix=human_prefix, - ai_prefix=ai_prefix, - ) 
- - -def fetch_prompt_messages( - *, - sys_query: str | None = None, - sys_files: Sequence[File], - context: str | None = None, - memory: PromptMessageMemory | None = None, - model_instance: ModelInstance, - prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate, - stop: Sequence[str] | None = None, - memory_config: MemoryConfig | None = None, - vision_enabled: bool = False, - vision_detail: ImagePromptMessageContent.DETAIL, - variable_pool: VariablePool, - jinja2_variables: Sequence[VariableSelector], - context_files: list[File] | None = None, - template_renderer: TemplateRenderer | None = None, -) -> tuple[Sequence[PromptMessage], Sequence[str] | None]: - prompt_messages: list[PromptMessage] = [] - model_schema = fetch_model_schema(model_instance=model_instance) - - if isinstance(prompt_template, list): - prompt_messages.extend( - handle_list_messages( - messages=prompt_template, - context=context, - jinja2_variables=jinja2_variables, - variable_pool=variable_pool, - vision_detail_config=vision_detail, - template_renderer=template_renderer, - ) - ) - - prompt_messages.extend( - handle_memory_chat_mode( - memory=memory, - memory_config=memory_config, - model_instance=model_instance, - ) - ) - - if sys_query: - prompt_messages.extend( - handle_list_messages( - messages=[ - LLMNodeChatModelMessage( - text=sys_query, - role=PromptMessageRole.USER, - edition_type="basic", - ) - ], - context="", - jinja2_variables=[], - variable_pool=variable_pool, - vision_detail_config=vision_detail, - template_renderer=template_renderer, - ) - ) - elif isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate): - prompt_messages.extend( - handle_completion_template( - template=prompt_template, - context=context, - jinja2_variables=jinja2_variables, - variable_pool=variable_pool, - template_renderer=template_renderer, - ) - ) - - memory_text = handle_memory_completion_mode( - memory=memory, - memory_config=memory_config, - 
model_instance=model_instance, - ) - prompt_content = prompt_messages[0].content - if isinstance(prompt_content, str): - prompt_content = str(prompt_content) - if "#histories#" in prompt_content: - prompt_content = prompt_content.replace("#histories#", memory_text) - else: - prompt_content = memory_text + "\n" + prompt_content - prompt_messages[0].content = prompt_content - elif isinstance(prompt_content, list): - for content_item in prompt_content: - if isinstance(content_item, TextPromptMessageContent): - if "#histories#" in content_item.data: - content_item.data = content_item.data.replace("#histories#", memory_text) - else: - content_item.data = memory_text + "\n" + content_item.data - else: - raise ValueError("Invalid prompt content type") - - if sys_query: - if isinstance(prompt_content, str): - prompt_messages[0].content = str(prompt_messages[0].content).replace("#sys.query#", sys_query) - elif isinstance(prompt_content, list): - for content_item in prompt_content: - if isinstance(content_item, TextPromptMessageContent): - content_item.data = sys_query + "\n" + content_item.data - else: - raise ValueError("Invalid prompt content type") - else: - raise TemplateTypeNotSupportError(type_name=str(type(prompt_template))) - - _append_file_prompts( - prompt_messages=prompt_messages, - files=sys_files, - vision_enabled=vision_enabled, - vision_detail=vision_detail, - ) - _append_file_prompts( - prompt_messages=prompt_messages, - files=context_files or [], - vision_enabled=vision_enabled, - vision_detail=vision_detail, - ) - - filtered_prompt_messages: list[PromptMessage] = [] - for prompt_message in prompt_messages: - if isinstance(prompt_message.content, list): - prompt_message_content: list[PromptMessageContentUnionTypes] = [] - for content_item in prompt_message.content: - if not model_schema.features: - if content_item.type == PromptMessageContentType.TEXT: - prompt_message_content.append(content_item) - continue - - if ( - ( - content_item.type == 
PromptMessageContentType.IMAGE - and ModelFeature.VISION not in model_schema.features - ) - or ( - content_item.type == PromptMessageContentType.DOCUMENT - and ModelFeature.DOCUMENT not in model_schema.features - ) - or ( - content_item.type == PromptMessageContentType.VIDEO - and ModelFeature.VIDEO not in model_schema.features - ) - or ( - content_item.type == PromptMessageContentType.AUDIO - and ModelFeature.AUDIO not in model_schema.features - ) - ): - continue - prompt_message_content.append(content_item) - if not prompt_message_content: - continue - if len(prompt_message_content) == 1 and prompt_message_content[0].type == PromptMessageContentType.TEXT: - prompt_message.content = prompt_message_content[0].data - else: - prompt_message.content = prompt_message_content - filtered_prompt_messages.append(prompt_message) - elif not prompt_message.is_empty(): - filtered_prompt_messages.append(prompt_message) - - if len(filtered_prompt_messages) == 0: - raise NoPromptFoundError( - "No prompt found in the LLM configuration. Please ensure a prompt is properly configured before proceeding." 
- ) - - return filtered_prompt_messages, stop - - -def handle_list_messages( - *, - messages: Sequence[LLMNodeChatModelMessage], - context: str | None, - jinja2_variables: Sequence[VariableSelector], - variable_pool: VariablePool, - vision_detail_config: ImagePromptMessageContent.DETAIL, - template_renderer: TemplateRenderer | None = None, -) -> Sequence[PromptMessage]: - prompt_messages: list[PromptMessage] = [] - for message in messages: - if message.edition_type == "jinja2": - result_text = render_jinja2_message( - template=message.jinja2_text or "", - jinja2_variables=jinja2_variables, - variable_pool=variable_pool, - template_renderer=template_renderer, - ) - prompt_messages.append( - combine_message_content_with_role( - contents=[TextPromptMessageContent(data=result_text)], - role=message.role, - ) - ) - continue - - template = message.text.replace("{#context#}", context) if context else message.text - segment_group = variable_pool.convert_template(template) - file_contents: list[PromptMessageContentUnionTypes] = [] - for segment in segment_group.value: - if isinstance(segment, ArrayFileSegment): - for file in segment.value: - if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}: - file_contents.append( - file_manager.to_prompt_message_content(file, image_detail_config=vision_detail_config) - ) - elif isinstance(segment, FileSegment): - file = segment.value - if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}: - file_contents.append( - file_manager.to_prompt_message_content(file, image_detail_config=vision_detail_config) - ) - - if segment_group.text: - prompt_messages.append( - combine_message_content_with_role( - contents=[TextPromptMessageContent(data=segment_group.text)], - role=message.role, - ) - ) - if file_contents: - prompt_messages.append(combine_message_content_with_role(contents=file_contents, role=message.role)) - - return prompt_messages - - -def render_jinja2_message( - *, - 
template: str, - jinja2_variables: Sequence[VariableSelector], - variable_pool: VariablePool, - template_renderer: TemplateRenderer | None = None, -) -> str: - if not template: - return "" - if template_renderer is None: - raise ValueError("template_renderer is required for jinja2 prompt rendering") - - jinja2_inputs: dict[str, Any] = {} - for jinja2_variable in jinja2_variables: - variable = variable_pool.get(jinja2_variable.value_selector) - jinja2_inputs[jinja2_variable.variable] = variable.to_object() if variable else "" - return template_renderer.render_jinja2(template=template, inputs=jinja2_inputs) - - -def handle_completion_template( - *, - template: LLMNodeCompletionModelPromptTemplate, - context: str | None, - jinja2_variables: Sequence[VariableSelector], - variable_pool: VariablePool, - template_renderer: TemplateRenderer | None = None, -) -> Sequence[PromptMessage]: - if template.edition_type == "jinja2": - result_text = render_jinja2_message( - template=template.jinja2_text or "", - jinja2_variables=jinja2_variables, - variable_pool=variable_pool, - template_renderer=template_renderer, - ) - else: - template_text = template.text.replace("{#context#}", context) if context else template.text - result_text = variable_pool.convert_template(template_text).text - return [ - combine_message_content_with_role( - contents=[TextPromptMessageContent(data=result_text)], - role=PromptMessageRole.USER, - ) - ] - - -def combine_message_content_with_role( - *, - contents: str | list[PromptMessageContentUnionTypes] | None = None, - role: PromptMessageRole, -) -> PromptMessage: - match role: - case PromptMessageRole.USER: - return UserPromptMessage(content=contents) - case PromptMessageRole.ASSISTANT: - return AssistantPromptMessage(content=contents) - case PromptMessageRole.SYSTEM: - return SystemPromptMessage(content=contents) - case _: - raise NotImplementedError(f"Role {role} is not supported") - - -def calculate_rest_token(*, prompt_messages: list[PromptMessage], 
model_instance: ModelInstance) -> int: - rest_tokens = 2000 - runtime_model_schema = fetch_model_schema(model_instance=model_instance) - runtime_model_parameters = model_instance.parameters - - model_context_tokens = runtime_model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) - if model_context_tokens: - curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages) - - max_tokens = 0 - for parameter_rule in runtime_model_schema.parameter_rules: - if parameter_rule.name == "max_tokens" or ( - parameter_rule.use_template and parameter_rule.use_template == "max_tokens" - ): - max_tokens = ( - runtime_model_parameters.get(parameter_rule.name) - or runtime_model_parameters.get(str(parameter_rule.use_template)) - or 0 - ) - - rest_tokens = model_context_tokens - max_tokens - curr_message_tokens - rest_tokens = max(rest_tokens, 0) - - return rest_tokens - - -def handle_memory_chat_mode( - *, - memory: PromptMessageMemory | None, - memory_config: MemoryConfig | None, - model_instance: ModelInstance, -) -> Sequence[PromptMessage]: - if not memory or not memory_config: - return [] - rest_tokens = calculate_rest_token(prompt_messages=[], model_instance=model_instance) - return memory.get_history_prompt_messages( - max_token_limit=rest_tokens, - message_limit=memory_config.window.size if memory_config.window.enabled else None, - ) - - -def handle_memory_completion_mode( - *, - memory: PromptMessageMemory | None, - memory_config: MemoryConfig | None, - model_instance: ModelInstance, -) -> str: - if not memory or not memory_config: - return "" - - rest_tokens = calculate_rest_token(prompt_messages=[], model_instance=model_instance) - if not memory_config.role_prefix: - raise MemoryRolePrefixRequiredError("Memory role prefix is required for completion model.") - - return fetch_memory_text( - memory=memory, - max_token_limit=rest_tokens, - message_limit=memory_config.window.size if memory_config.window.enabled else None, - 
human_prefix=memory_config.role_prefix.user, - ai_prefix=memory_config.role_prefix.assistant, - ) - - -def _append_file_prompts( - *, - prompt_messages: list[PromptMessage], - files: Sequence[File], - vision_enabled: bool, - vision_detail: ImagePromptMessageContent.DETAIL, -) -> None: - if not vision_enabled or not files: - return - - file_prompts = [file_manager.to_prompt_message_content(file, image_detail_config=vision_detail) for file in files] - if ( - prompt_messages - and isinstance(prompt_messages[-1], UserPromptMessage) - and isinstance(prompt_messages[-1].content, list) - ): - existing_contents = prompt_messages[-1].content - assert isinstance(existing_contents, list) - prompt_messages[-1] = UserPromptMessage(content=file_prompts + existing_contents) - else: - prompt_messages.append(UserPromptMessage(content=file_prompts)) - - -def _coerce_resolved_value(raw: str) -> int | float | bool | str: - """Try to restore the original type from a resolved template string. - - Variable references are always resolved to text, but completion params may - expect numeric or boolean values (e.g. a variable that holds "0.7" mapped to - the ``temperature`` parameter). This helper attempts a JSON parse so that - ``"0.7"`` → ``0.7``, ``"true"`` → ``True``, etc. Plain strings that are not - valid JSON literals are returned as-is. - """ - stripped = raw.strip() - if not stripped: - return raw - - try: - parsed: object = json.loads(stripped) - except (json.JSONDecodeError, ValueError): - return raw - - if isinstance(parsed, (int, float, bool)): - return parsed - return raw - - -def resolve_completion_params_variables( - completion_params: Mapping[str, Any], - variable_pool: VariablePool, -) -> dict[str, Any]: - """Resolve variable references (``{{#node_id.var#}}``) in string-typed completion params. - - Security notes: - - Resolved values are length-capped to ``MAX_RESOLVED_VALUE_LENGTH`` to - prevent denial-of-service through excessively large variable payloads. 
- - This follows the same ``VariablePool.convert_template`` pattern used across - Dify (Answer Node, HTTP Request Node, Agent Node, etc.). The downstream - model plugin receives these values as structured JSON key-value pairs — they - are never concatenated into raw HTTP headers or SQL queries. - - Numeric/boolean coercion is applied so that variables holding ``"0.7"`` are - restored to their native type rather than sent as a bare string. - """ - resolved: dict[str, Any] = {} - for key, value in completion_params.items(): - if isinstance(value, str) and VARIABLE_PATTERN.search(value): - segment_group = variable_pool.convert_template(value) - text = segment_group.text - if len(text) > MAX_RESOLVED_VALUE_LENGTH: - logger.warning( - "Resolved value for param '%s' truncated from %d to %d chars", - key, - len(text), - MAX_RESOLVED_VALUE_LENGTH, - ) - text = text[:MAX_RESOLVED_VALUE_LENGTH] - resolved[key] = _coerce_resolved_value(text) - else: - resolved[key] = value - return resolved diff --git a/api/dify_graph/nodes/llm/node.py b/api/dify_graph/nodes/llm/node.py deleted file mode 100644 index a5492aee6b..0000000000 --- a/api/dify_graph/nodes/llm/node.py +++ /dev/null @@ -1,1035 +0,0 @@ -from __future__ import annotations - -import base64 -import io -import json -import logging -import re -import time -from collections.abc import Generator, Mapping, Sequence -from typing import TYPE_CHECKING, Any, Literal - -from sqlalchemy import select - -from core.llm_generator.output_parser.errors import OutputParserError -from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output -from core.model_manager import ModelInstance -from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig -from core.prompt.utils.prompt_message_util import PromptMessageUtil -from core.tools.signature import sign_upload_file -from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID -from dify_graph.entities import 
GraphInitParams -from dify_graph.entities.graph_config import NodeConfigDict -from dify_graph.enums import ( - BuiltinNodeTypes, - NodeType, - SystemVariableKey, - WorkflowNodeExecutionMetadataKey, - WorkflowNodeExecutionStatus, -) -from dify_graph.file import File, FileTransferMethod, FileType -from dify_graph.model_runtime.entities import ( - ImagePromptMessageContent, - PromptMessage, - TextPromptMessageContent, -) -from dify_graph.model_runtime.entities.llm_entities import ( - LLMResult, - LLMResultChunk, - LLMResultChunkWithStructuredOutput, - LLMResultWithStructuredOutput, - LLMStructuredOutput, - LLMUsage, -) -from dify_graph.model_runtime.entities.message_entities import PromptMessageContentUnionTypes -from dify_graph.model_runtime.memory import PromptMessageMemory -from dify_graph.model_runtime.utils.encoders import jsonable_encoder -from dify_graph.node_events import ( - ModelInvokeCompletedEvent, - NodeEventBase, - NodeRunResult, - RunRetrieverResourceEvent, - StreamChunkEvent, - StreamCompletedEvent, -) -from dify_graph.nodes.base.entities import VariableSelector -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser -from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer -from dify_graph.nodes.protocols import HttpClientProtocol -from dify_graph.runtime import VariablePool -from dify_graph.variables import ( - ArrayFileSegment, - ArraySegment, - NoneSegment, - ObjectSegment, - StringSegment, -) -from extensions.ext_database import db -from models.dataset import SegmentAttachmentBinding -from models.model import UploadFile - -from . 
import llm_utils -from .entities import ( - LLMNodeChatModelMessage, - LLMNodeCompletionModelPromptTemplate, - LLMNodeData, -) -from .exc import ( - InvalidContextStructureError, - InvalidVariableTypeError, - LLMNodeError, - VariableNotFoundError, -) -from .file_saver import FileSaverImpl, LLMFileSaver - -if TYPE_CHECKING: - from dify_graph.file.models import File - from dify_graph.runtime import GraphRuntimeState - -logger = logging.getLogger(__name__) - - -class LLMNode(Node[LLMNodeData]): - node_type = BuiltinNodeTypes.LLM - - # Compiled regex for extracting blocks (with compatibility for attributes) - _THINK_PATTERN = re.compile(r"]*>(.*?)", re.IGNORECASE | re.DOTALL) - - # Instance attributes specific to LLMNode. - # Output variable for file - _file_outputs: list[File] - - _llm_file_saver: LLMFileSaver - _credentials_provider: CredentialsProvider - _model_factory: ModelFactory - _model_instance: ModelInstance - _memory: PromptMessageMemory | None - _template_renderer: TemplateRenderer - - def __init__( - self, - id: str, - config: NodeConfigDict, - graph_init_params: GraphInitParams, - graph_runtime_state: GraphRuntimeState, - *, - credentials_provider: CredentialsProvider, - model_factory: ModelFactory, - model_instance: ModelInstance, - http_client: HttpClientProtocol, - template_renderer: TemplateRenderer, - memory: PromptMessageMemory | None = None, - llm_file_saver: LLMFileSaver | None = None, - ): - super().__init__( - id=id, - config=config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - # LLM file outputs, used for MultiModal outputs. 
- self._file_outputs = [] - - self._credentials_provider = credentials_provider - self._model_factory = model_factory - self._model_instance = model_instance - self._memory = memory - self._template_renderer = template_renderer - - if llm_file_saver is None: - dify_ctx = self.require_dify_context() - llm_file_saver = FileSaverImpl( - user_id=dify_ctx.user_id, - tenant_id=dify_ctx.tenant_id, - http_client=http_client, - ) - self._llm_file_saver = llm_file_saver - - @classmethod - def version(cls) -> str: - return "1" - - def _run(self) -> Generator: - node_inputs: dict[str, Any] = {} - process_data: dict[str, Any] = {} - result_text = "" - clean_text = "" - usage = LLMUsage.empty_usage() - finish_reason = None - reasoning_content = None - variable_pool = self.graph_runtime_state.variable_pool - - try: - # init messages template - self.node_data.prompt_template = self._transform_chat_messages(self.node_data.prompt_template) - - # fetch variables and fetch values from variable pool - inputs = self._fetch_inputs(node_data=self.node_data) - - # fetch jinja2 inputs - jinja_inputs = self._fetch_jinja_inputs(node_data=self.node_data) - - # merge inputs - inputs.update(jinja_inputs) - - # fetch files - files = ( - llm_utils.fetch_files( - variable_pool=variable_pool, - selector=self.node_data.vision.configs.variable_selector, - ) - if self.node_data.vision.enabled - else [] - ) - - if files: - node_inputs["#files#"] = [file.to_dict() for file in files] - - # fetch context value - generator = self._fetch_context(node_data=self.node_data) - context = None - context_files: list[File] = [] - for event in generator: - context = event.context - context_files = event.context_files or [] - yield event - if context: - node_inputs["#context#"] = context - - if context_files: - node_inputs["#context_files#"] = [file.model_dump() for file in context_files] - - # fetch model config - model_instance = self._model_instance - # Resolve variable references in string-typed completion params 
- model_instance.parameters = llm_utils.resolve_completion_params_variables( - model_instance.parameters, variable_pool - ) - model_name = model_instance.model_name - model_provider = model_instance.provider - model_stop = model_instance.stop - - memory = self._memory - - query: str | None = None - if self.node_data.memory: - query = self.node_data.memory.query_prompt_template - if not query and ( - query_variable := variable_pool.get((SYSTEM_VARIABLE_NODE_ID, SystemVariableKey.QUERY)) - ): - query = query_variable.text - - prompt_messages, stop = LLMNode.fetch_prompt_messages( - sys_query=query, - sys_files=files, - context=context, - memory=memory, - model_instance=model_instance, - stop=model_stop, - prompt_template=self.node_data.prompt_template, - memory_config=self.node_data.memory, - vision_enabled=self.node_data.vision.enabled, - vision_detail=self.node_data.vision.configs.detail, - variable_pool=variable_pool, - jinja2_variables=self.node_data.prompt_config.jinja2_variables, - context_files=context_files, - template_renderer=self._template_renderer, - ) - - # handle invoke result - generator = LLMNode.invoke_llm( - model_instance=model_instance, - prompt_messages=prompt_messages, - stop=stop, - user_id=self.require_dify_context().user_id, - structured_output_enabled=self.node_data.structured_output_enabled, - structured_output=self.node_data.structured_output, - file_saver=self._llm_file_saver, - file_outputs=self._file_outputs, - node_id=self._node_id, - node_type=self.node_type, - reasoning_format=self.node_data.reasoning_format, - ) - - structured_output: LLMStructuredOutput | None = None - - for event in generator: - if isinstance(event, StreamChunkEvent): - yield event - elif isinstance(event, ModelInvokeCompletedEvent): - # Raw text - result_text = event.text - usage = event.usage - finish_reason = event.finish_reason - reasoning_content = event.reasoning_content or "" - - # For downstream nodes, determine clean text based on reasoning_format - if 
self.node_data.reasoning_format == "tagged": - # Keep tags for backward compatibility - clean_text = result_text - else: - # Extract clean text from tags - clean_text, _ = LLMNode._split_reasoning(result_text, self.node_data.reasoning_format) - - # Process structured output if available from the event. - structured_output = ( - LLMStructuredOutput(structured_output=event.structured_output) - if event.structured_output - else None - ) - - break - elif isinstance(event, LLMStructuredOutput): - structured_output = event - - process_data = { - "model_mode": self.node_data.model.mode, - "prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving( - model_mode=self.node_data.model.mode, prompt_messages=prompt_messages - ), - "usage": jsonable_encoder(usage), - "finish_reason": finish_reason, - "model_provider": model_provider, - "model_name": model_name, - } - - outputs = { - "text": clean_text, - "reasoning_content": reasoning_content, - "usage": jsonable_encoder(usage), - "finish_reason": finish_reason, - } - if structured_output: - outputs["structured_output"] = structured_output.structured_output - if self._file_outputs: - outputs["files"] = ArrayFileSegment(value=self._file_outputs) - - # Send final chunk event to indicate streaming is complete - yield StreamChunkEvent( - selector=[self._node_id, "text"], - chunk="", - is_final=True, - ) - - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=node_inputs, - process_data=process_data, - outputs=outputs, - metadata={ - WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens, - WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price, - WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency, - }, - llm_usage=usage, - ) - ) - except ValueError as e: - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error=str(e), - inputs=node_inputs, - process_data=process_data, - 
error_type=type(e).__name__, - llm_usage=usage, - ) - ) - except Exception as e: - logger.exception("error while executing llm node") - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error=str(e), - inputs=node_inputs, - process_data=process_data, - error_type=type(e).__name__, - llm_usage=usage, - ) - ) - - @staticmethod - def invoke_llm( - *, - model_instance: ModelInstance, - prompt_messages: Sequence[PromptMessage], - stop: Sequence[str] | None = None, - user_id: str, - structured_output_enabled: bool, - structured_output: Mapping[str, Any] | None = None, - file_saver: LLMFileSaver, - file_outputs: list[File], - node_id: str, - node_type: NodeType, - reasoning_format: Literal["separated", "tagged"] = "tagged", - ) -> Generator[NodeEventBase | LLMStructuredOutput, None, None]: - model_parameters = model_instance.parameters - invoke_model_parameters = dict(model_parameters) - - model_schema = llm_utils.fetch_model_schema(model_instance=model_instance) - - if structured_output_enabled: - output_schema = LLMNode.fetch_structured_output_schema( - structured_output=structured_output or {}, - ) - request_start_time = time.perf_counter() - - invoke_result = invoke_llm_with_structured_output( - provider=model_instance.provider, - model_schema=model_schema, - model_instance=model_instance, - prompt_messages=prompt_messages, - json_schema=output_schema, - model_parameters=invoke_model_parameters, - stop=list(stop or []), - stream=True, - user=user_id, - ) - else: - request_start_time = time.perf_counter() - - invoke_result = model_instance.invoke_llm( - prompt_messages=list(prompt_messages), - model_parameters=invoke_model_parameters, - stop=list(stop or []), - stream=True, - user=user_id, - ) - - return LLMNode.handle_invoke_result( - invoke_result=invoke_result, - file_saver=file_saver, - file_outputs=file_outputs, - node_id=node_id, - node_type=node_type, - reasoning_format=reasoning_format, - 
request_start_time=request_start_time, - ) - - @staticmethod - def handle_invoke_result( - *, - invoke_result: LLMResult | Generator[LLMResultChunk | LLMStructuredOutput, None, None], - file_saver: LLMFileSaver, - file_outputs: list[File], - node_id: str, - node_type: NodeType, - reasoning_format: Literal["separated", "tagged"] = "tagged", - request_start_time: float | None = None, - ) -> Generator[NodeEventBase | LLMStructuredOutput, None, None]: - # For blocking mode - if isinstance(invoke_result, LLMResult): - duration = None - if request_start_time is not None: - duration = time.perf_counter() - request_start_time - invoke_result.usage.latency = round(duration, 3) - event = LLMNode.handle_blocking_result( - invoke_result=invoke_result, - saver=file_saver, - file_outputs=file_outputs, - reasoning_format=reasoning_format, - request_latency=duration, - ) - yield event - return - - # For streaming mode - model = "" - prompt_messages: list[PromptMessage] = [] - - usage = LLMUsage.empty_usage() - finish_reason = None - full_text_buffer = io.StringIO() - - # Initialize streaming metrics tracking - start_time = request_start_time if request_start_time is not None else time.perf_counter() - first_token_time = None - has_content = False - - collected_structured_output = None # Collect structured_output from streaming chunks - # Consume the invoke result and handle generator exception - try: - for result in invoke_result: - if isinstance(result, LLMResultChunkWithStructuredOutput): - # Collect structured_output from the chunk - if result.structured_output is not None: - collected_structured_output = dict(result.structured_output) - yield result - if isinstance(result, LLMResultChunk): - contents = result.delta.message.content - for text_part in LLMNode._save_multimodal_output_and_convert_result_to_markdown( - contents=contents, - file_saver=file_saver, - file_outputs=file_outputs, - ): - # Detect first token for TTFT calculation - if text_part and not has_content: - 
first_token_time = time.perf_counter() - has_content = True - - full_text_buffer.write(text_part) - yield StreamChunkEvent( - selector=[node_id, "text"], - chunk=text_part, - is_final=False, - ) - - # Update the whole metadata - if not model and result.model: - model = result.model - if len(prompt_messages) == 0: - # TODO(QuantumGhost): it seems that this update has no visable effect. - # What's the purpose of the line below? - prompt_messages = list(result.prompt_messages) - if usage.prompt_tokens == 0 and result.delta.usage: - usage = result.delta.usage - if finish_reason is None and result.delta.finish_reason: - finish_reason = result.delta.finish_reason - except OutputParserError as e: - raise LLMNodeError(f"Failed to parse structured output: {e}") - - # Extract reasoning content from tags in the main text - full_text = full_text_buffer.getvalue() - - if reasoning_format == "tagged": - # Keep tags in text for backward compatibility - clean_text = full_text - reasoning_content = "" - else: - # Extract clean text and reasoning from tags - clean_text, reasoning_content = LLMNode._split_reasoning(full_text, reasoning_format) - - # Calculate streaming metrics - end_time = time.perf_counter() - total_duration = end_time - start_time - usage.latency = round(total_duration, 3) - if has_content and first_token_time: - gen_ai_server_time_to_first_token = first_token_time - start_time - llm_streaming_time_to_generate = end_time - first_token_time - usage.time_to_first_token = round(gen_ai_server_time_to_first_token, 3) - usage.time_to_generate = round(llm_streaming_time_to_generate, 3) - - yield ModelInvokeCompletedEvent( - # Use clean_text for separated mode, full_text for tagged mode - text=clean_text if reasoning_format == "separated" else full_text, - usage=usage, - finish_reason=finish_reason, - # Reasoning content for workflow variables and downstream nodes - reasoning_content=reasoning_content, - # Pass structured output if collected from streaming chunks - 
structured_output=collected_structured_output, - ) - - @staticmethod - def _image_file_to_markdown(file: File, /): - text_chunk = f"![]({file.generate_url()})" - return text_chunk - - @classmethod - def _split_reasoning( - cls, text: str, reasoning_format: Literal["separated", "tagged"] = "tagged" - ) -> tuple[str, str]: - """ - Split reasoning content from text based on reasoning_format strategy. - - Args: - text: Full text that may contain blocks - reasoning_format: Strategy for handling reasoning content - - "separated": Remove tags and return clean text + reasoning_content field - - "tagged": Keep tags in text, return empty reasoning_content - - Returns: - tuple of (clean_text, reasoning_content) - """ - - if reasoning_format == "tagged": - return text, "" - - # Find all ... blocks (case-insensitive) - matches = cls._THINK_PATTERN.findall(text) - - # Extract reasoning content from all blocks - reasoning_content = "\n".join(match.strip() for match in matches) if matches else "" - - # Remove all ... 
blocks from original text - clean_text = cls._THINK_PATTERN.sub("", text) - - # Clean up extra whitespace - clean_text = re.sub(r"\n\s*\n", "\n\n", clean_text).strip() - - # Separated mode: always return clean text and reasoning_content - return clean_text, reasoning_content or "" - - def _transform_chat_messages( - self, messages: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate, / - ) -> Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate: - if isinstance(messages, LLMNodeCompletionModelPromptTemplate): - if messages.edition_type == "jinja2" and messages.jinja2_text: - messages.text = messages.jinja2_text - - return messages - - for message in messages: - if message.edition_type == "jinja2" and message.jinja2_text: - message.text = message.jinja2_text - - return messages - - def _fetch_jinja_inputs(self, node_data: LLMNodeData) -> dict[str, str]: - variables: dict[str, Any] = {} - - if not node_data.prompt_config: - return variables - - for variable_selector in node_data.prompt_config.jinja2_variables or []: - variable_name = variable_selector.variable - variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector) - if variable is None: - raise VariableNotFoundError(f"Variable {variable_selector.variable} not found") - - def parse_dict(input_dict: Mapping[str, Any]) -> str: - """ - Parse dict into string - """ - # check if it's a context structure - if "metadata" in input_dict and "_source" in input_dict["metadata"] and "content" in input_dict: - return str(input_dict["content"]) - - # else, parse the dict - try: - return json.dumps(input_dict, ensure_ascii=False) - except Exception: - return str(input_dict) - - if isinstance(variable, ArraySegment): - result = "" - for item in variable.value: - if isinstance(item, dict): - result += parse_dict(item) - else: - result += str(item) - result += "\n" - value = result.strip() - elif isinstance(variable, ObjectSegment): - value = 
parse_dict(variable.value) - else: - value = variable.text - - variables[variable_name] = value - - return variables - - def _fetch_inputs(self, node_data: LLMNodeData) -> dict[str, Any]: - inputs = {} - prompt_template = node_data.prompt_template - - variable_selectors = [] - if isinstance(prompt_template, list): - for prompt in prompt_template: - variable_template_parser = VariableTemplateParser(template=prompt.text) - variable_selectors.extend(variable_template_parser.extract_variable_selectors()) - elif isinstance(prompt_template, CompletionModelPromptTemplate): - variable_template_parser = VariableTemplateParser(template=prompt_template.text) - variable_selectors = variable_template_parser.extract_variable_selectors() - - for variable_selector in variable_selectors: - variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector) - if variable is None: - raise VariableNotFoundError(f"Variable {variable_selector.variable} not found") - if isinstance(variable, NoneSegment): - inputs[variable_selector.variable] = "" - inputs[variable_selector.variable] = variable.to_object() - - memory = node_data.memory - if memory and memory.query_prompt_template: - query_variable_selectors = VariableTemplateParser( - template=memory.query_prompt_template - ).extract_variable_selectors() - for variable_selector in query_variable_selectors: - variable = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector) - if variable is None: - raise VariableNotFoundError(f"Variable {variable_selector.variable} not found") - if isinstance(variable, NoneSegment): - continue - inputs[variable_selector.variable] = variable.to_object() - - return inputs - - def _fetch_context(self, node_data: LLMNodeData): - if not node_data.context.enabled: - return - - if not node_data.context.variable_selector: - return - - context_value_variable = self.graph_runtime_state.variable_pool.get(node_data.context.variable_selector) - if context_value_variable: - if 
isinstance(context_value_variable, StringSegment): - yield RunRetrieverResourceEvent( - retriever_resources=[], context=context_value_variable.value, context_files=[] - ) - elif isinstance(context_value_variable, ArraySegment): - context_str = "" - original_retriever_resource: list[dict[str, Any]] = [] - context_files: list[File] = [] - for item in context_value_variable.value: - if isinstance(item, str): - context_str += item + "\n" - else: - if "content" not in item: - raise InvalidContextStructureError(f"Invalid context structure: {item}") - - if item.get("summary"): - context_str += item["summary"] + "\n" - context_str += item["content"] + "\n" - - retriever_resource = self._convert_to_original_retriever_resource(item) - if retriever_resource: - original_retriever_resource.append(retriever_resource) - segment_id = retriever_resource.get("segment_id") - if not segment_id: - continue - attachments_with_bindings = db.session.execute( - select(SegmentAttachmentBinding, UploadFile) - .join(UploadFile, UploadFile.id == SegmentAttachmentBinding.attachment_id) - .where( - SegmentAttachmentBinding.segment_id == segment_id, - ) - ).all() - if attachments_with_bindings: - for _, upload_file in attachments_with_bindings: - attachment_info = File( - id=upload_file.id, - filename=upload_file.name, - extension="." 
+ upload_file.extension, - mime_type=upload_file.mime_type, - tenant_id=self.require_dify_context().tenant_id, - type=FileType.IMAGE, - transfer_method=FileTransferMethod.LOCAL_FILE, - remote_url=upload_file.source_url, - related_id=upload_file.id, - size=upload_file.size, - storage_key=upload_file.key, - url=sign_upload_file(upload_file.id, upload_file.extension), - ) - context_files.append(attachment_info) - yield RunRetrieverResourceEvent( - retriever_resources=original_retriever_resource, - context=context_str.strip(), - context_files=context_files, - ) - - def _convert_to_original_retriever_resource(self, context_dict: dict) -> dict[str, Any] | None: - if ( - "metadata" in context_dict - and "_source" in context_dict["metadata"] - and context_dict["metadata"]["_source"] == "knowledge" - ): - metadata = context_dict.get("metadata", {}) - - return { - "position": metadata.get("position"), - "dataset_id": metadata.get("dataset_id"), - "dataset_name": metadata.get("dataset_name"), - "document_id": metadata.get("document_id"), - "document_name": metadata.get("document_name"), - "data_source_type": metadata.get("data_source_type"), - "segment_id": metadata.get("segment_id"), - "retriever_from": metadata.get("retriever_from"), - "score": metadata.get("score"), - "hit_count": metadata.get("segment_hit_count"), - "word_count": metadata.get("segment_word_count"), - "segment_position": metadata.get("segment_position"), - "index_node_hash": metadata.get("segment_index_node_hash"), - "content": context_dict.get("content"), - "page": metadata.get("page"), - "doc_metadata": metadata.get("doc_metadata"), - "files": context_dict.get("files"), - "summary": context_dict.get("summary"), - } - - return None - - @staticmethod - def fetch_prompt_messages( - *, - sys_query: str | None = None, - sys_files: Sequence[File], - context: str | None = None, - memory: PromptMessageMemory | None = None, - model_instance: ModelInstance, - prompt_template: Sequence[LLMNodeChatModelMessage] | 
LLMNodeCompletionModelPromptTemplate, - stop: Sequence[str] | None = None, - memory_config: MemoryConfig | None = None, - vision_enabled: bool = False, - vision_detail: ImagePromptMessageContent.DETAIL, - variable_pool: VariablePool, - jinja2_variables: Sequence[VariableSelector], - context_files: list[File] | None = None, - template_renderer: TemplateRenderer | None = None, - ) -> tuple[Sequence[PromptMessage], Sequence[str] | None]: - return llm_utils.fetch_prompt_messages( - sys_query=sys_query, - sys_files=sys_files, - context=context, - memory=memory, - model_instance=model_instance, - prompt_template=prompt_template, - stop=stop, - memory_config=memory_config, - vision_enabled=vision_enabled, - vision_detail=vision_detail, - variable_pool=variable_pool, - jinja2_variables=jinja2_variables, - context_files=context_files, - template_renderer=template_renderer, - ) - - @classmethod - def _extract_variable_selector_to_variable_mapping( - cls, - *, - graph_config: Mapping[str, Any], - node_id: str, - node_data: LLMNodeData, - ) -> Mapping[str, Sequence[str]]: - # graph_config is not used in this node type - _ = graph_config # Explicitly mark as unused - prompt_template = node_data.prompt_template - variable_selectors = [] - if isinstance(prompt_template, list): - for prompt in prompt_template: - if prompt.edition_type != "jinja2": - variable_template_parser = VariableTemplateParser(template=prompt.text) - variable_selectors.extend(variable_template_parser.extract_variable_selectors()) - elif isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate): - if prompt_template.edition_type != "jinja2": - variable_template_parser = VariableTemplateParser(template=prompt_template.text) - variable_selectors = variable_template_parser.extract_variable_selectors() - else: - raise InvalidVariableTypeError(f"Invalid prompt template type: {type(prompt_template)}") - - variable_mapping: dict[str, Any] = {} - for variable_selector in variable_selectors: - 
variable_mapping[variable_selector.variable] = variable_selector.value_selector - - memory = node_data.memory - if memory and memory.query_prompt_template: - query_variable_selectors = VariableTemplateParser( - template=memory.query_prompt_template - ).extract_variable_selectors() - for variable_selector in query_variable_selectors: - variable_mapping[variable_selector.variable] = variable_selector.value_selector - - if node_data.context.enabled: - variable_mapping["#context#"] = node_data.context.variable_selector - - if node_data.vision.enabled: - variable_mapping["#files#"] = node_data.vision.configs.variable_selector - - if node_data.memory: - variable_mapping["#sys.query#"] = ["sys", SystemVariableKey.QUERY] - - if node_data.prompt_config: - enable_jinja = False - - if isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate): - if prompt_template.edition_type == "jinja2": - enable_jinja = True - else: - for prompt in prompt_template: - if prompt.edition_type == "jinja2": - enable_jinja = True - break - - if enable_jinja: - for variable_selector in node_data.prompt_config.jinja2_variables or []: - variable_mapping[variable_selector.variable] = variable_selector.value_selector - - variable_mapping = {node_id + "." 
+ key: value for key, value in variable_mapping.items()} - - return variable_mapping - - @classmethod - def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]: - return { - "type": "llm", - "config": { - "prompt_templates": { - "chat_model": { - "prompts": [ - {"role": "system", "text": "You are a helpful AI assistant.", "edition_type": "basic"} - ] - }, - "completion_model": { - "conversation_histories_role": {"user_prefix": "Human", "assistant_prefix": "Assistant"}, - "prompt": { - "text": "Here are the chat histories between human and assistant, inside " - " XML tags.\n\n\n{{" - "#histories#}}\n\n\n\nHuman: {{#sys.query#}}\n\nAssistant:", - "edition_type": "basic", - }, - "stop": ["Human:"], - }, - } - }, - } - - @staticmethod - def handle_list_messages( - *, - messages: Sequence[LLMNodeChatModelMessage], - context: str | None, - jinja2_variables: Sequence[VariableSelector], - variable_pool: VariablePool, - vision_detail_config: ImagePromptMessageContent.DETAIL, - template_renderer: TemplateRenderer | None = None, - ) -> Sequence[PromptMessage]: - return llm_utils.handle_list_messages( - messages=messages, - context=context, - jinja2_variables=jinja2_variables, - variable_pool=variable_pool, - vision_detail_config=vision_detail_config, - template_renderer=template_renderer, - ) - - @staticmethod - def handle_blocking_result( - *, - invoke_result: LLMResult | LLMResultWithStructuredOutput, - saver: LLMFileSaver, - file_outputs: list[File], - reasoning_format: Literal["separated", "tagged"] = "tagged", - request_latency: float | None = None, - ) -> ModelInvokeCompletedEvent: - buffer = io.StringIO() - for text_part in LLMNode._save_multimodal_output_and_convert_result_to_markdown( - contents=invoke_result.message.content, - file_saver=saver, - file_outputs=file_outputs, - ): - buffer.write(text_part) - - # Extract reasoning content from tags in the main text - full_text = buffer.getvalue() - - if reasoning_format == 
"tagged": - # Keep tags in text for backward compatibility - clean_text = full_text - reasoning_content = "" - else: - # Extract clean text and reasoning from tags - clean_text, reasoning_content = LLMNode._split_reasoning(full_text, reasoning_format) - - event = ModelInvokeCompletedEvent( - # Use clean_text for separated mode, full_text for tagged mode - text=clean_text if reasoning_format == "separated" else full_text, - usage=invoke_result.usage, - finish_reason=None, - # Reasoning content for workflow variables and downstream nodes - reasoning_content=reasoning_content, - # Pass structured output if enabled - structured_output=getattr(invoke_result, "structured_output", None), - ) - if request_latency is not None: - event.usage.latency = round(request_latency, 3) - return event - - @staticmethod - def save_multimodal_image_output( - *, - content: ImagePromptMessageContent, - file_saver: LLMFileSaver, - ) -> File: - """_save_multimodal_output saves multi-modal contents generated by LLM plugins. - - There are two kinds of multimodal outputs: - - - Inlined data encoded in base64, which would be saved to storage directly. - - Remote files referenced by an url, which would be downloaded and then saved to storage. - - Currently, only image files are supported. - """ - if content.url != "": - saved_file = file_saver.save_remote_url(content.url, FileType.IMAGE) - else: - saved_file = file_saver.save_binary_string( - data=base64.b64decode(content.base64_data), - mime_type=content.mime_type, - file_type=FileType.IMAGE, - ) - return saved_file - - @staticmethod - def fetch_structured_output_schema( - *, - structured_output: Mapping[str, Any], - ) -> dict[str, Any]: - """ - Fetch the structured output schema from the node data. 
- - Returns: - dict[str, Any]: The structured output schema - """ - if not structured_output: - raise LLMNodeError("Please provide a valid structured output schema") - structured_output_schema = json.dumps(structured_output.get("schema", {}), ensure_ascii=False) - if not structured_output_schema: - raise LLMNodeError("Please provide a valid structured output schema") - - try: - schema = json.loads(structured_output_schema) - if not isinstance(schema, dict): - raise LLMNodeError("structured_output_schema must be a JSON object") - return schema - except json.JSONDecodeError: - raise LLMNodeError("structured_output_schema is not valid JSON format") - - @staticmethod - def _save_multimodal_output_and_convert_result_to_markdown( - *, - contents: str | list[PromptMessageContentUnionTypes] | None, - file_saver: LLMFileSaver, - file_outputs: list[File], - ) -> Generator[str, None, None]: - """Convert intermediate prompt messages into strings and yield them to the caller. - - If the messages contain non-textual content (e.g., multimedia like images or videos), - it will be saved separately, and the corresponding Markdown representation will - be yielded to the caller. - """ - - # NOTE(QuantumGhost): This function should yield results to the caller immediately - # whenever new content or partial content is available. Avoid any intermediate buffering - # of results. Additionally, do not yield empty strings; instead, yield from an empty list - # if necessary. 
- if contents is None: - yield from [] - return - if isinstance(contents, str): - yield contents - else: - for item in contents: - if isinstance(item, TextPromptMessageContent): - yield item.data - elif isinstance(item, ImagePromptMessageContent): - file = LLMNode.save_multimodal_image_output( - content=item, - file_saver=file_saver, - ) - file_outputs.append(file) - yield LLMNode._image_file_to_markdown(file) - else: - logger.warning("unknown item type encountered, type=%s", type(item)) - yield str(item) - - @property - def retry(self) -> bool: - return self.node_data.retry_config.retry_enabled - - @property - def model_instance(self) -> ModelInstance: - return self._model_instance diff --git a/api/dify_graph/nodes/llm/protocols.py b/api/dify_graph/nodes/llm/protocols.py deleted file mode 100644 index 9e95d341c9..0000000000 --- a/api/dify_graph/nodes/llm/protocols.py +++ /dev/null @@ -1,30 +0,0 @@ -from __future__ import annotations - -from collections.abc import Mapping -from typing import Any, Protocol - -from core.model_manager import ModelInstance - - -class CredentialsProvider(Protocol): - """Port for loading runtime credentials for a provider/model pair.""" - - def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]: - """Return credentials for the target provider/model or raise a domain error.""" - ... - - -class ModelFactory(Protocol): - """Port for creating initialized LLM model instances for execution.""" - - def init_model_instance(self, provider_name: str, model_name: str) -> ModelInstance: - """Create a model instance that is ready for schema lookup and invocation.""" - ... - - -class TemplateRenderer(Protocol): - """Port for rendering prompt templates used by LLM-compatible nodes.""" - - def render_jinja2(self, *, template: str, inputs: Mapping[str, Any]) -> str: - """Render the given Jinja2 template into plain text.""" - ... 
diff --git a/api/dify_graph/nodes/loop/__init__.py b/api/dify_graph/nodes/loop/__init__.py deleted file mode 100644 index 9fe695607b..0000000000 --- a/api/dify_graph/nodes/loop/__init__.py +++ /dev/null @@ -1,6 +0,0 @@ -from .entities import LoopNodeData -from .loop_end_node import LoopEndNode -from .loop_node import LoopNode -from .loop_start_node import LoopStartNode - -__all__ = ["LoopEndNode", "LoopNode", "LoopNodeData", "LoopStartNode"] diff --git a/api/dify_graph/nodes/loop/entities.py b/api/dify_graph/nodes/loop/entities.py deleted file mode 100644 index f0bfad5a0f..0000000000 --- a/api/dify_graph/nodes/loop/entities.py +++ /dev/null @@ -1,107 +0,0 @@ -from enum import StrEnum -from typing import Annotated, Any, Literal - -from pydantic import AfterValidator, BaseModel, Field, field_validator - -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, NodeType -from dify_graph.nodes.base import BaseLoopNodeData, BaseLoopState -from dify_graph.utils.condition.entities import Condition -from dify_graph.variables.types import SegmentType - -_VALID_VAR_TYPE = frozenset( - [ - SegmentType.STRING, - SegmentType.NUMBER, - SegmentType.OBJECT, - SegmentType.BOOLEAN, - SegmentType.ARRAY_STRING, - SegmentType.ARRAY_NUMBER, - SegmentType.ARRAY_OBJECT, - SegmentType.ARRAY_BOOLEAN, - ] -) - - -def _is_valid_var_type(seg_type: SegmentType) -> SegmentType: - if seg_type not in _VALID_VAR_TYPE: - raise ValueError(...) - return seg_type - - -class LoopVariableData(BaseModel): - """ - Loop Variable Data. 
- """ - - label: str - var_type: Annotated[SegmentType, AfterValidator(_is_valid_var_type)] - value_type: Literal["variable", "constant"] - value: Any | list[str] | None = None - - -class LoopNodeData(BaseLoopNodeData): - type: NodeType = BuiltinNodeTypes.LOOP - loop_count: int # Maximum number of loops - break_conditions: list[Condition] # Conditions to break the loop - logical_operator: Literal["and", "or"] - loop_variables: list[LoopVariableData] | None = Field(default_factory=list[LoopVariableData]) - outputs: dict[str, Any] = Field(default_factory=dict) - - @field_validator("outputs", mode="before") - @classmethod - def validate_outputs(cls, v): - if v is None: - return {} - return v - - -class LoopStartNodeData(BaseNodeData): - """ - Loop Start Node Data. - """ - - type: NodeType = BuiltinNodeTypes.LOOP_START - - -class LoopEndNodeData(BaseNodeData): - """ - Loop End Node Data. - """ - - type: NodeType = BuiltinNodeTypes.LOOP_END - - -class LoopState(BaseLoopState): - """ - Loop State. - """ - - outputs: list[Any] = Field(default_factory=list) - current_output: Any = None - - class MetaData(BaseLoopState.MetaData): - """ - Data. - """ - - loop_length: int - - def get_last_output(self) -> Any: - """ - Get last output. - """ - if self.outputs: - return self.outputs[-1] - return None - - def get_current_output(self) -> Any: - """ - Get current output. 
- """ - return self.current_output - - -class LoopCompletedReason(StrEnum): - LOOP_BREAK = "loop_break" - LOOP_COMPLETED = "loop_completed" diff --git a/api/dify_graph/nodes/loop/loop_end_node.py b/api/dify_graph/nodes/loop/loop_end_node.py deleted file mode 100644 index 0287708fb3..0000000000 --- a/api/dify_graph/nodes/loop/loop_end_node.py +++ /dev/null @@ -1,22 +0,0 @@ -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.loop.entities import LoopEndNodeData - - -class LoopEndNode(Node[LoopEndNodeData]): - """ - Loop End Node. - """ - - node_type = BuiltinNodeTypes.LOOP_END - - @classmethod - def version(cls) -> str: - return "1" - - def _run(self) -> NodeRunResult: - """ - Run the node. - """ - return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED) diff --git a/api/dify_graph/nodes/loop/loop_node.py b/api/dify_graph/nodes/loop/loop_node.py deleted file mode 100644 index 3c546ffa23..0000000000 --- a/api/dify_graph/nodes/loop/loop_node.py +++ /dev/null @@ -1,435 +0,0 @@ -import contextlib -import json -import logging -from collections.abc import Callable, Generator, Mapping, Sequence -from datetime import datetime -from typing import TYPE_CHECKING, Any, Literal, cast - -from dify_graph.entities.graph_config import NodeConfigDictAdapter -from dify_graph.enums import ( - BuiltinNodeTypes, - NodeExecutionType, - WorkflowNodeExecutionMetadataKey, - WorkflowNodeExecutionStatus, -) -from dify_graph.graph_events import ( - GraphNodeEventBase, - GraphRunFailedEvent, - NodeRunSucceededEvent, -) -from dify_graph.model_runtime.entities.llm_entities import LLMUsage -from dify_graph.node_events import ( - LoopFailedEvent, - LoopNextEvent, - LoopStartedEvent, - LoopSucceededEvent, - NodeEventBase, - NodeRunResult, - StreamCompletedEvent, -) -from dify_graph.nodes.base import LLMUsageTrackingMixin -from 
dify_graph.nodes.base.node import Node -from dify_graph.nodes.loop.entities import LoopCompletedReason, LoopNodeData, LoopVariableData -from dify_graph.utils.condition.processor import ConditionProcessor -from dify_graph.variables import Segment, SegmentType -from factories.variable_factory import TypeMismatchError, build_segment_with_type, segment_to_variable -from libs.datetime_utils import naive_utc_now - -if TYPE_CHECKING: - from dify_graph.graph_engine import GraphEngine - -logger = logging.getLogger(__name__) - - -class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]): - """ - Loop Node. - """ - - node_type = BuiltinNodeTypes.LOOP - execution_type = NodeExecutionType.CONTAINER - - @classmethod - def version(cls) -> str: - return "1" - - def _run(self) -> Generator: - """Run the node.""" - # Get inputs - loop_count = self.node_data.loop_count - break_conditions = self.node_data.break_conditions - logical_operator = self.node_data.logical_operator - - inputs = {"loop_count": loop_count} - - if not self.node_data.start_node_id: - raise ValueError(f"field start_node_id in loop {self._node_id} not found") - - root_node_id = self.node_data.start_node_id - - # Initialize loop variables in the original variable pool - loop_variable_selectors = {} - if self.node_data.loop_variables: - value_processor: dict[Literal["constant", "variable"], Callable[[LoopVariableData], Segment | None]] = { - "constant": lambda var: self._get_segment_for_constant(var.var_type, var.value), - "variable": lambda var: ( - self.graph_runtime_state.variable_pool.get(var.value) if isinstance(var.value, list) else None - ), - } - for loop_variable in self.node_data.loop_variables: - if loop_variable.value_type not in value_processor: - raise ValueError( - f"Invalid value type '{loop_variable.value_type}' for loop variable {loop_variable.label}" - ) - - processed_segment = value_processor[loop_variable.value_type](loop_variable) - if not processed_segment: - raise ValueError(f"Invalid value 
for loop variable {loop_variable.label}") - variable_selector = [self._node_id, loop_variable.label] - variable = segment_to_variable(segment=processed_segment, selector=variable_selector) - self.graph_runtime_state.variable_pool.add(variable_selector, variable.value) - loop_variable_selectors[loop_variable.label] = variable_selector - inputs[loop_variable.label] = processed_segment.value - - start_at = naive_utc_now() - condition_processor = ConditionProcessor() - - loop_duration_map: dict[str, float] = {} - single_loop_variable_map: dict[str, dict[str, Any]] = {} # single loop variable output - loop_usage = LLMUsage.empty_usage() - loop_node_ids = self._extract_loop_node_ids_from_config(self.graph_config, self._node_id) - - # Start Loop event - yield LoopStartedEvent( - start_at=start_at, - inputs=inputs, - metadata={"loop_length": loop_count}, - ) - - try: - reach_break_condition = False - if break_conditions: - with contextlib.suppress(ValueError): - _, _, reach_break_condition = condition_processor.process_conditions( - variable_pool=self.graph_runtime_state.variable_pool, - conditions=break_conditions, - operator=logical_operator, - ) - - if reach_break_condition: - loop_count = 0 - - for i in range(loop_count): - # Clear stale variables from previous loop iterations to avoid streaming old values - self._clear_loop_subgraph_variables(loop_node_ids) - graph_engine = self._create_graph_engine(start_at=start_at, root_node_id=root_node_id) - - loop_start_time = naive_utc_now() - reach_break_node = yield from self._run_single_loop(graph_engine=graph_engine, current_index=i) - # Track loop duration - loop_duration_map[str(i)] = (naive_utc_now() - loop_start_time).total_seconds() - - # Accumulate outputs from the sub-graph's response nodes - for key, value in graph_engine.graph_runtime_state.outputs.items(): - if key == "answer": - # Concatenate answer outputs with newline - existing_answer = self.graph_runtime_state.get_output("answer", "") - if existing_answer: - 
self.graph_runtime_state.set_output("answer", f"{existing_answer}{value}") - else: - self.graph_runtime_state.set_output("answer", value) - else: - # For other outputs, just update - self.graph_runtime_state.set_output(key, value) - - # Accumulate usage from the sub-graph execution - loop_usage = self._merge_usage(loop_usage, graph_engine.graph_runtime_state.llm_usage) - - # Collect loop variable values after iteration - single_loop_variable = {} - for key, selector in loop_variable_selectors.items(): - segment = self.graph_runtime_state.variable_pool.get(selector) - single_loop_variable[key] = segment.value if segment else None - - single_loop_variable_map[str(i)] = single_loop_variable - - if reach_break_node: - break - - if break_conditions: - _, _, reach_break_condition = condition_processor.process_conditions( - variable_pool=self.graph_runtime_state.variable_pool, - conditions=break_conditions, - operator=logical_operator, - ) - if reach_break_condition: - break - - yield LoopNextEvent( - index=i + 1, - pre_loop_output=self.node_data.outputs, - ) - - self._accumulate_usage(loop_usage) - # Loop completed successfully - yield LoopSucceededEvent( - start_at=start_at, - inputs=inputs, - outputs=self.node_data.outputs, - steps=loop_count, - metadata={ - WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: loop_usage.total_tokens, - WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: loop_usage.total_price, - WorkflowNodeExecutionMetadataKey.CURRENCY: loop_usage.currency, - WorkflowNodeExecutionMetadataKey.COMPLETED_REASON: ( - LoopCompletedReason.LOOP_BREAK - if reach_break_condition - else LoopCompletedReason.LOOP_COMPLETED.value - ), - WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map, - WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map, - }, - ) - - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - metadata={ - WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: 
loop_usage.total_tokens, - WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: loop_usage.total_price, - WorkflowNodeExecutionMetadataKey.CURRENCY: loop_usage.currency, - WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map, - WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map, - }, - outputs=self.node_data.outputs, - inputs=inputs, - llm_usage=loop_usage, - ) - ) - - except Exception as e: - self._accumulate_usage(loop_usage) - yield LoopFailedEvent( - start_at=start_at, - inputs=inputs, - steps=loop_count, - metadata={ - WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: loop_usage.total_tokens, - WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: loop_usage.total_price, - WorkflowNodeExecutionMetadataKey.CURRENCY: loop_usage.currency, - "completed_reason": "error", - WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map, - WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map, - }, - error=str(e), - ) - - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error=str(e), - metadata={ - WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: loop_usage.total_tokens, - WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: loop_usage.total_price, - WorkflowNodeExecutionMetadataKey.CURRENCY: loop_usage.currency, - WorkflowNodeExecutionMetadataKey.LOOP_DURATION_MAP: loop_duration_map, - WorkflowNodeExecutionMetadataKey.LOOP_VARIABLE_MAP: single_loop_variable_map, - }, - llm_usage=loop_usage, - ) - ) - - def _run_single_loop( - self, - *, - graph_engine: "GraphEngine", - current_index: int, - ) -> Generator[NodeEventBase | GraphNodeEventBase, None, bool]: - reach_break_node = False - for event in graph_engine.run(): - if isinstance(event, GraphNodeEventBase): - self._append_loop_info_to_event(event=event, loop_run_index=current_index) - - if isinstance(event, GraphNodeEventBase) and event.node_type == BuiltinNodeTypes.LOOP_START: - continue - if isinstance(event, 
GraphNodeEventBase): - yield event - if isinstance(event, NodeRunSucceededEvent) and event.node_type == BuiltinNodeTypes.LOOP_END: - reach_break_node = True - if isinstance(event, GraphRunFailedEvent): - raise Exception(event.error) - - for loop_var in self.node_data.loop_variables or []: - key, sel = loop_var.label, [self._node_id, loop_var.label] - segment = self.graph_runtime_state.variable_pool.get(sel) - self.node_data.outputs[key] = segment.value if segment else None - self.node_data.outputs["loop_round"] = current_index + 1 - - return reach_break_node - - def _append_loop_info_to_event( - self, - event: GraphNodeEventBase, - loop_run_index: int, - ): - event.in_loop_id = self._node_id - loop_metadata = { - WorkflowNodeExecutionMetadataKey.LOOP_ID: self._node_id, - WorkflowNodeExecutionMetadataKey.LOOP_INDEX: loop_run_index, - } - - current_metadata = event.node_run_result.metadata - if WorkflowNodeExecutionMetadataKey.LOOP_ID not in current_metadata: - event.node_run_result.metadata = {**current_metadata, **loop_metadata} - - def _clear_loop_subgraph_variables(self, loop_node_ids: set[str]) -> None: - """ - Remove variables produced by loop sub-graph nodes from previous iterations. - - Keeping stale variables causes a freshly created response coordinator in the - next iteration to fall back to outdated values when no stream chunks exist. 
- """ - variable_pool = self.graph_runtime_state.variable_pool - for node_id in loop_node_ids: - variable_pool.remove([node_id]) - - @classmethod - def _extract_variable_selector_to_variable_mapping( - cls, - *, - graph_config: Mapping[str, Any], - node_id: str, - node_data: LoopNodeData, - ) -> Mapping[str, Sequence[str]]: - variable_mapping = {} - - # Extract loop node IDs statically from graph_config - - loop_node_ids = cls._extract_loop_node_ids_from_config(graph_config, node_id) - - # Get node configs from graph_config - node_configs = {node["id"]: node for node in graph_config.get("nodes", []) if "id" in node} - for sub_node_id, sub_node_config in node_configs.items(): - if sub_node_config.get("data", {}).get("loop_id") != node_id: - continue - - # variable selector to variable mapping - try: - typed_sub_node_config = NodeConfigDictAdapter.validate_python(sub_node_config) - node_type = typed_sub_node_config["data"].type - node_mapping = Node.get_node_type_classes_mapping() - if node_type not in node_mapping: - continue - node_version = str(typed_sub_node_config["data"].version) - node_cls = node_mapping[node_type][node_version] - - sub_node_variable_mapping = node_cls.extract_variable_selector_to_variable_mapping( - graph_config=graph_config, config=typed_sub_node_config - ) - sub_node_variable_mapping = cast(dict[str, Sequence[str]], sub_node_variable_mapping) - except NotImplementedError: - sub_node_variable_mapping = {} - - # remove loop variables - sub_node_variable_mapping = { - sub_node_id + "." 
+ key: value - for key, value in sub_node_variable_mapping.items() - if value[0] != node_id - } - - variable_mapping.update(sub_node_variable_mapping) - - for loop_variable in node_data.loop_variables or []: - if loop_variable.value_type == "variable": - assert loop_variable.value is not None, "Loop variable value must be provided for variable type" - # add loop variable to variable mapping - selector = loop_variable.value - variable_mapping[f"{node_id}.{loop_variable.label}"] = selector - - # remove variable out from loop - variable_mapping = {key: value for key, value in variable_mapping.items() if value[0] not in loop_node_ids} - - return variable_mapping - - @classmethod - def _extract_loop_node_ids_from_config(cls, graph_config: Mapping[str, Any], loop_node_id: str) -> set[str]: - """ - Extract node IDs that belong to a specific loop from graph configuration. - - This method statically analyzes the graph configuration to find all nodes - that are part of the specified loop, without creating actual node instances. - - :param graph_config: the complete graph configuration - :param loop_node_id: the ID of the loop node - :return: set of node IDs that belong to the loop - """ - loop_node_ids = set() - - # Find all nodes that belong to this loop - nodes = graph_config.get("nodes", []) - for node in nodes: - node_data = node.get("data", {}) - if node_data.get("loop_id") == loop_node_id: - node_id = node.get("id") - if node_id: - loop_node_ids.add(node_id) - - return loop_node_ids - - @staticmethod - def _get_segment_for_constant(var_type: SegmentType, original_value: Any) -> Segment: - """Get the appropriate segment type for a constant value.""" - # TODO: Refactor for maintainability: - # 1. Ensure type handling logic stays synchronized with _VALID_VAR_TYPE (entities.py) - # 2. 
Consider moving this method to LoopVariableData class for better encapsulation - if not var_type.is_array_type() or var_type == SegmentType.ARRAY_BOOLEAN: - value = original_value - elif var_type in [ - SegmentType.ARRAY_NUMBER, - SegmentType.ARRAY_OBJECT, - SegmentType.ARRAY_STRING, - ]: - if original_value and isinstance(original_value, str): - value = json.loads(original_value) - else: - logger.warning("unexpected value for LoopNode, value_type=%s, value=%s", original_value, var_type) - value = [] - else: - raise AssertionError("this statement should be unreachable.") - try: - return build_segment_with_type(var_type, value=value) - except TypeMismatchError as type_exc: - # Attempt to parse the value as a JSON-encoded string, if applicable. - if not isinstance(original_value, str): - raise - try: - value = json.loads(original_value) - except ValueError: - raise type_exc - return build_segment_with_type(var_type, value) - - def _create_graph_engine(self, start_at: datetime, root_node_id: str): - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState - - # Create GraphInitParams for child graph execution. 
- graph_init_params = GraphInitParams( - workflow_id=self.workflow_id, - graph_config=self.graph_config, - run_context=self.run_context, - call_depth=self.workflow_call_depth, - ) - - # Create a new GraphRuntimeState for this iteration - graph_runtime_state_copy = GraphRuntimeState( - variable_pool=self.graph_runtime_state.variable_pool, - start_at=start_at.timestamp(), - ) - - return self.graph_runtime_state.create_child_engine( - workflow_id=self.workflow_id, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state_copy, - graph_config=self.graph_config, - root_node_id=root_node_id, - ) diff --git a/api/dify_graph/nodes/loop/loop_start_node.py b/api/dify_graph/nodes/loop/loop_start_node.py deleted file mode 100644 index e171b4df2f..0000000000 --- a/api/dify_graph/nodes/loop/loop_start_node.py +++ /dev/null @@ -1,22 +0,0 @@ -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.loop.entities import LoopStartNodeData - - -class LoopStartNode(Node[LoopStartNodeData]): - """ - Loop Start Node. - """ - - node_type = BuiltinNodeTypes.LOOP_START - - @classmethod - def version(cls) -> str: - return "1" - - def _run(self) -> NodeRunResult: - """ - Run the node. 
- """ - return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED) diff --git a/api/dify_graph/nodes/parameter_extractor/__init__.py b/api/dify_graph/nodes/parameter_extractor/__init__.py deleted file mode 100644 index bdbf19a7d3..0000000000 --- a/api/dify_graph/nodes/parameter_extractor/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .parameter_extractor_node import ParameterExtractorNode - -__all__ = ["ParameterExtractorNode"] diff --git a/api/dify_graph/nodes/parameter_extractor/entities.py b/api/dify_graph/nodes/parameter_extractor/entities.py deleted file mode 100644 index 2fb042c16c..0000000000 --- a/api/dify_graph/nodes/parameter_extractor/entities.py +++ /dev/null @@ -1,131 +0,0 @@ -from typing import Annotated, Any, Literal - -from pydantic import ( - BaseModel, - BeforeValidator, - Field, - field_validator, -) - -from core.prompt.entities.advanced_prompt_entities import MemoryConfig -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, NodeType -from dify_graph.nodes.llm.entities import ModelConfig, VisionConfig -from dify_graph.variables.types import SegmentType - -_OLD_BOOL_TYPE_NAME = "bool" -_OLD_SELECT_TYPE_NAME = "select" - -_VALID_PARAMETER_TYPES = frozenset( - [ - SegmentType.STRING, # "string", - SegmentType.NUMBER, # "number", - SegmentType.BOOLEAN, - SegmentType.ARRAY_STRING, - SegmentType.ARRAY_NUMBER, - SegmentType.ARRAY_OBJECT, - SegmentType.ARRAY_BOOLEAN, - _OLD_BOOL_TYPE_NAME, # old boolean type used by Parameter Extractor node - _OLD_SELECT_TYPE_NAME, # string type with enumeration choices. 
- ] -) - - -def _validate_type(parameter_type: str) -> SegmentType: - if parameter_type not in _VALID_PARAMETER_TYPES: - raise ValueError(f"type {parameter_type} is not allowd to use in Parameter Extractor node.") - - if parameter_type == _OLD_BOOL_TYPE_NAME: - return SegmentType.BOOLEAN - elif parameter_type == _OLD_SELECT_TYPE_NAME: - return SegmentType.STRING - return SegmentType(parameter_type) - - -class ParameterConfig(BaseModel): - """ - Parameter Config. - """ - - name: str - type: Annotated[SegmentType, BeforeValidator(_validate_type)] - options: list[str] | None = None - description: str - required: bool - - @field_validator("name", mode="before") - @classmethod - def validate_name(cls, value) -> str: - if not value: - raise ValueError("Parameter name is required") - if value in {"__reason", "__is_success"}: - raise ValueError("Invalid parameter name, __reason and __is_success are reserved") - return str(value) - - def is_array_type(self) -> bool: - return self.type.is_array_type() - - def element_type(self) -> SegmentType: - """Return the element type of the parameter. - - Raises a ValueError if the parameter's type is not an array type. - """ - element_type = self.type.element_type() - # At this point, self.type is guaranteed to be one of `ARRAY_STRING`, - # `ARRAY_NUMBER`, `ARRAY_OBJECT`, or `ARRAY_BOOLEAN`. - # - # See: _VALID_PARAMETER_TYPES for reference. - assert element_type is not None, f"the element type should not be None, {self.type=}" - return element_type - - -class ParameterExtractorNodeData(BaseNodeData): - """ - Parameter Extractor Node Data. 
- """ - - type: NodeType = BuiltinNodeTypes.PARAMETER_EXTRACTOR - model: ModelConfig - query: list[str] - parameters: list[ParameterConfig] - instruction: str | None = None - memory: MemoryConfig | None = None - reasoning_mode: Literal["function_call", "prompt"] - vision: VisionConfig = Field(default_factory=VisionConfig) - - @field_validator("reasoning_mode", mode="before") - @classmethod - def set_reasoning_mode(cls, v) -> str: - return v or "function_call" - - def get_parameter_json_schema(self): - """ - Get parameter json schema. - - :return: parameter json schema - """ - parameters: dict[str, Any] = {"type": "object", "properties": {}, "required": []} - - for parameter in self.parameters: - parameter_schema: dict[str, Any] = {"description": parameter.description} - - if parameter.type == SegmentType.STRING: - parameter_schema["type"] = "string" - elif parameter.type.is_array_type(): - parameter_schema["type"] = "array" - element_type = parameter.type.element_type() - if element_type is None: - raise AssertionError("element type should not be None.") - parameter_schema["items"] = {"type": element_type.value} - else: - parameter_schema["type"] = parameter.type - - if parameter.options: - parameter_schema["enum"] = parameter.options - - parameters["properties"][parameter.name] = parameter_schema - - if parameter.required: - parameters["required"].append(parameter.name) - - return parameters diff --git a/api/dify_graph/nodes/parameter_extractor/exc.py b/api/dify_graph/nodes/parameter_extractor/exc.py deleted file mode 100644 index c25b809d1c..0000000000 --- a/api/dify_graph/nodes/parameter_extractor/exc.py +++ /dev/null @@ -1,75 +0,0 @@ -from typing import Any - -from dify_graph.variables.types import SegmentType - - -class ParameterExtractorNodeError(ValueError): - """Base error for ParameterExtractorNode.""" - - -class InvalidModelTypeError(ParameterExtractorNodeError): - """Raised when the model is not a Large Language Model.""" - - -class 
ModelSchemaNotFoundError(ParameterExtractorNodeError): - """Raised when the model schema is not found.""" - - -class InvalidInvokeResultError(ParameterExtractorNodeError): - """Raised when the invoke result is invalid.""" - - -class InvalidTextContentTypeError(ParameterExtractorNodeError): - """Raised when the text content type is invalid.""" - - -class InvalidNumberOfParametersError(ParameterExtractorNodeError): - """Raised when the number of parameters is invalid.""" - - -class RequiredParameterMissingError(ParameterExtractorNodeError): - """Raised when a required parameter is missing.""" - - -class InvalidSelectValueError(ParameterExtractorNodeError): - """Raised when a select value is invalid.""" - - -class InvalidNumberValueError(ParameterExtractorNodeError): - """Raised when a number value is invalid.""" - - -class InvalidBoolValueError(ParameterExtractorNodeError): - """Raised when a bool value is invalid.""" - - -class InvalidStringValueError(ParameterExtractorNodeError): - """Raised when a string value is invalid.""" - - -class InvalidArrayValueError(ParameterExtractorNodeError): - """Raised when an array value is invalid.""" - - -class InvalidModelModeError(ParameterExtractorNodeError): - """Raised when the model mode is invalid.""" - - -class InvalidValueTypeError(ParameterExtractorNodeError): - def __init__( - self, - /, - parameter_name: str, - expected_type: SegmentType, - actual_type: SegmentType | None, - value: Any, - ): - message = ( - f"Invalid value for parameter {parameter_name}, expected segment type: {expected_type}, " - f"actual_type: {actual_type}, python_type: {type(value)}, value: {value}" - ) - super().__init__(message) - self.parameter_name = parameter_name - self.expected_type = expected_type - self.actual_type = actual_type - self.value = value diff --git a/api/dify_graph/nodes/parameter_extractor/parameter_extractor_node.py b/api/dify_graph/nodes/parameter_extractor/parameter_extractor_node.py deleted file mode 100644 index 
e6e8a44d06..0000000000 --- a/api/dify_graph/nodes/parameter_extractor/parameter_extractor_node.py +++ /dev/null @@ -1,857 +0,0 @@ -import contextlib -import json -import logging -import uuid -from collections.abc import Mapping, Sequence -from typing import TYPE_CHECKING, Any, cast - -from core.model_manager import ModelInstance -from core.prompt.advanced_prompt_transform import AdvancedPromptTransform -from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate -from core.prompt.simple_prompt_transform import ModelMode -from core.prompt.utils.prompt_message_util import PromptMessageUtil -from dify_graph.entities.graph_config import NodeConfigDict -from dify_graph.enums import ( - BuiltinNodeTypes, - WorkflowNodeExecutionMetadataKey, - WorkflowNodeExecutionStatus, -) -from dify_graph.file import File -from dify_graph.model_runtime.entities import ImagePromptMessageContent -from dify_graph.model_runtime.entities.llm_entities import LLMUsage -from dify_graph.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - PromptMessage, - PromptMessageRole, - PromptMessageTool, - ToolPromptMessage, - UserPromptMessage, -) -from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey -from dify_graph.model_runtime.memory import PromptMessageMemory -from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from dify_graph.model_runtime.utils.encoders import jsonable_encoder -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.base import variable_template_parser -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.llm import llm_utils -from dify_graph.runtime import VariablePool -from dify_graph.variables.types import ArrayValidation, SegmentType -from factories.variable_factory import build_segment_with_type - -from .entities import ParameterExtractorNodeData -from .exc import ( - InvalidModelModeError, 
- InvalidModelTypeError, - InvalidNumberOfParametersError, - InvalidSelectValueError, - InvalidTextContentTypeError, - InvalidValueTypeError, - ModelSchemaNotFoundError, - ParameterExtractorNodeError, - RequiredParameterMissingError, -) -from .prompts import ( - CHAT_EXAMPLE, - CHAT_GENERATE_JSON_PROMPT, - CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE, - COMPLETION_GENERATE_JSON_PROMPT, - FUNCTION_CALLING_EXTRACTOR_EXAMPLE, - FUNCTION_CALLING_EXTRACTOR_NAME, - FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT, - FUNCTION_CALLING_EXTRACTOR_USER_TEMPLATE, -) - -logger = logging.getLogger(__name__) - -if TYPE_CHECKING: - from dify_graph.entities import GraphInitParams - from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory - from dify_graph.runtime import GraphRuntimeState - - -def extract_json(text): - """ - From a given JSON started from '{' or '[' extract the complete JSON object. - """ - stack = [] - for i, c in enumerate(text): - if c in {"{", "["}: - stack.append(c) - elif c in {"}", "]"}: - # check if stack is empty - if not stack: - return text[:i] - # check if the last element in stack is matching - if (c == "}" and stack[-1] == "{") or (c == "]" and stack[-1] == "["): - stack.pop() - if not stack: - return text[: i + 1] - else: - return text[:i] - return None - - -class ParameterExtractorNode(Node[ParameterExtractorNodeData]): - """ - Parameter Extractor Node. 
- """ - - node_type = BuiltinNodeTypes.PARAMETER_EXTRACTOR - - _model_instance: ModelInstance - _credentials_provider: "CredentialsProvider" - _model_factory: "ModelFactory" - _memory: PromptMessageMemory | None - - def __init__( - self, - id: str, - config: NodeConfigDict, - graph_init_params: "GraphInitParams", - graph_runtime_state: "GraphRuntimeState", - *, - credentials_provider: "CredentialsProvider", - model_factory: "ModelFactory", - model_instance: ModelInstance, - memory: PromptMessageMemory | None = None, - ) -> None: - super().__init__( - id=id, - config=config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - self._credentials_provider = credentials_provider - self._model_factory = model_factory - self._model_instance = model_instance - self._memory = memory - - @classmethod - def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]: - return { - "model": { - "prompt_templates": { - "completion_model": { - "conversation_histories_role": {"user_prefix": "Human", "assistant_prefix": "Assistant"}, - "stop": ["Human:"], - } - } - } - } - - @classmethod - def version(cls) -> str: - return "1" - - def _run(self): - """ - Run the node. 
- """ - node_data = self.node_data - variable = self.graph_runtime_state.variable_pool.get(node_data.query) - query = variable.text if variable else "" - - variable_pool = self.graph_runtime_state.variable_pool - - files = ( - llm_utils.fetch_files( - variable_pool=variable_pool, - selector=node_data.vision.configs.variable_selector, - ) - if node_data.vision.enabled - else [] - ) - - model_instance = self._model_instance - # Resolve variable references in string-typed completion params - model_instance.parameters = llm_utils.resolve_completion_params_variables( - model_instance.parameters, variable_pool - ) - if not isinstance(model_instance.model_type_instance, LargeLanguageModel): - raise InvalidModelTypeError("Model is not a Large Language Model") - - try: - model_schema = llm_utils.fetch_model_schema(model_instance=model_instance) - except ValueError as exc: - raise ModelSchemaNotFoundError("Model schema not found") from exc - memory = self._memory - - if ( - set(model_schema.features or []) & {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL} - and node_data.reasoning_mode == "function_call" - ): - # use function call - prompt_messages, prompt_message_tools = self._generate_function_call_prompt( - node_data=node_data, - query=query, - variable_pool=self.graph_runtime_state.variable_pool, - model_instance=model_instance, - memory=memory, - files=files, - vision_detail=node_data.vision.configs.detail, - ) - else: - # use prompt engineering - prompt_messages = self._generate_prompt_engineering_prompt( - data=node_data, - query=query, - variable_pool=self.graph_runtime_state.variable_pool, - model_instance=model_instance, - memory=memory, - files=files, - vision_detail=node_data.vision.configs.detail, - ) - - prompt_message_tools = [] - - inputs = { - "query": query, - "files": [f.to_dict() for f in files], - "parameters": jsonable_encoder(node_data.parameters), - "instruction": jsonable_encoder(node_data.instruction), - } - - process_data = { - "model_mode": 
node_data.model.mode, - "prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving( - model_mode=node_data.model.mode, prompt_messages=prompt_messages - ), - "usage": None, - "function": {} if not prompt_message_tools else jsonable_encoder(prompt_message_tools[0]), - "tool_call": None, - "model_provider": model_instance.provider, - "model_name": model_instance.model_name, - } - - try: - text, usage, tool_call = self._invoke( - model_instance=model_instance, - prompt_messages=prompt_messages, - tools=prompt_message_tools, - stop=model_instance.stop, - ) - process_data["usage"] = jsonable_encoder(usage) - process_data["tool_call"] = jsonable_encoder(tool_call) - process_data["llm_text"] = text - except ParameterExtractorNodeError as e: - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - inputs=inputs, - process_data=process_data, - outputs={"__is_success": 0, "__reason": str(e)}, - error=str(e), - metadata={}, - ) - except Exception as e: - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - inputs=inputs, - process_data=process_data, - outputs={"__is_success": 0, "__reason": "Failed to invoke model", "__error": str(e)}, - error=str(e), - metadata={}, - ) - - error = None - - if tool_call: - result = self._extract_json_from_tool_call(tool_call) - else: - result = self._extract_complete_json_response(text) - if not result: - result = self._generate_default_result(node_data) - error = "Failed to extract result from function call or text response, using empty result." 
- - try: - result = self._validate_result(data=node_data, result=result or {}) - except ParameterExtractorNodeError as e: - error = str(e) - - # transform result into standard format - result = self._transform_result(data=node_data, result=result or {}) - - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=inputs, - process_data=process_data, - outputs={ - "__is_success": 1 if not error else 0, - "__reason": error, - "__usage": jsonable_encoder(usage), - **result, - }, - metadata={ - WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens, - WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price, - WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency, - }, - llm_usage=usage, - ) - - def _invoke( - self, - model_instance: ModelInstance, - prompt_messages: list[PromptMessage], - tools: list[PromptMessageTool], - stop: Sequence[str], - ) -> tuple[str, LLMUsage, AssistantPromptMessage.ToolCall | None]: - invoke_result = model_instance.invoke_llm( - prompt_messages=prompt_messages, - model_parameters=dict(model_instance.parameters), - tools=tools, - stop=list(stop), - stream=False, - user=self.require_dify_context().user_id, - ) - - # handle invoke result - - text = invoke_result.message.get_text_content() - if not isinstance(text, str): - raise InvalidTextContentTypeError(f"Invalid text content type: {type(text)}. Expected str.") - - usage = invoke_result.usage - tool_call = invoke_result.message.tool_calls[0] if invoke_result.message.tool_calls else None - - return text, usage, tool_call - - def _generate_function_call_prompt( - self, - node_data: ParameterExtractorNodeData, - query: str, - variable_pool: VariablePool, - model_instance: ModelInstance, - memory: PromptMessageMemory | None, - files: Sequence[File], - vision_detail: ImagePromptMessageContent.DETAIL | None = None, - ) -> tuple[list[PromptMessage], list[PromptMessageTool]]: - """ - Generate function call prompt. 
- """ - query = FUNCTION_CALLING_EXTRACTOR_USER_TEMPLATE.format( - content=query, structure=json.dumps(node_data.get_parameter_json_schema()) - ) - - prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True) - rest_token = self._calculate_rest_token( - node_data=node_data, - query=query, - variable_pool=variable_pool, - model_instance=model_instance, - context="", - ) - prompt_template = self._get_function_calling_prompt_template( - node_data, query, variable_pool, memory, rest_token - ) - prompt_messages = prompt_transform.get_prompt( - prompt_template=prompt_template, - inputs={}, - query="", - files=files, - context="", - memory_config=node_data.memory, - memory=None, - model_instance=model_instance, - image_detail_config=vision_detail, - ) - - # find last user message - last_user_message_idx = -1 - for i, prompt_message in enumerate(prompt_messages): - if prompt_message.role == PromptMessageRole.USER: - last_user_message_idx = i - - # add function call messages before last user message - example_messages = [] - for example in FUNCTION_CALLING_EXTRACTOR_EXAMPLE: - id = uuid.uuid4().hex - example_messages.extend( - [ - UserPromptMessage(content=example["user"]["query"]), - AssistantPromptMessage( - content=example["assistant"]["text"], - tool_calls=[ - AssistantPromptMessage.ToolCall( - id=id, - type="function", - function=AssistantPromptMessage.ToolCall.ToolCallFunction( - name=example["assistant"]["function_call"]["name"], - arguments=json.dumps(example["assistant"]["function_call"]["parameters"]), - ), - ) - ], - ), - ToolPromptMessage( - content="Great! 
You have called the function with the correct parameters.", tool_call_id=id - ), - AssistantPromptMessage( - content="I have extracted the parameters, let's move on.", - ), - ] - ) - - prompt_messages = ( - prompt_messages[:last_user_message_idx] + example_messages + prompt_messages[last_user_message_idx:] - ) - - # generate tool - tool = PromptMessageTool( - name=FUNCTION_CALLING_EXTRACTOR_NAME, - description="Extract parameters from the natural language text", - parameters=node_data.get_parameter_json_schema(), - ) - - return prompt_messages, [tool] - - def _generate_prompt_engineering_prompt( - self, - data: ParameterExtractorNodeData, - query: str, - variable_pool: VariablePool, - model_instance: ModelInstance, - memory: PromptMessageMemory | None, - files: Sequence[File], - vision_detail: ImagePromptMessageContent.DETAIL | None = None, - ) -> list[PromptMessage]: - """ - Generate prompt engineering prompt. - """ - model_mode = ModelMode(data.model.mode) - - if model_mode == ModelMode.COMPLETION: - return self._generate_prompt_engineering_completion_prompt( - node_data=data, - query=query, - variable_pool=variable_pool, - model_instance=model_instance, - memory=memory, - files=files, - vision_detail=vision_detail, - ) - elif model_mode == ModelMode.CHAT: - return self._generate_prompt_engineering_chat_prompt( - node_data=data, - query=query, - variable_pool=variable_pool, - model_instance=model_instance, - memory=memory, - files=files, - vision_detail=vision_detail, - ) - else: - raise InvalidModelModeError(f"Invalid model mode: {model_mode}") - - def _generate_prompt_engineering_completion_prompt( - self, - node_data: ParameterExtractorNodeData, - query: str, - variable_pool: VariablePool, - model_instance: ModelInstance, - memory: PromptMessageMemory | None, - files: Sequence[File], - vision_detail: ImagePromptMessageContent.DETAIL | None = None, - ) -> list[PromptMessage]: - """ - Generate completion prompt. 
- """ - prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True) - rest_token = self._calculate_rest_token( - node_data=node_data, - query=query, - variable_pool=variable_pool, - model_instance=model_instance, - context="", - ) - prompt_template = self._get_prompt_engineering_prompt_template( - node_data=node_data, query=query, variable_pool=variable_pool, memory=memory, max_token_limit=rest_token - ) - prompt_messages = prompt_transform.get_prompt( - prompt_template=prompt_template, - inputs={"structure": json.dumps(node_data.get_parameter_json_schema())}, - query="", - files=files, - context="", - memory_config=node_data.memory, - # AdvancedPromptTransform is still typed against TokenBufferMemory. - memory=cast(Any, memory), - model_instance=model_instance, - image_detail_config=vision_detail, - ) - - return prompt_messages - - def _generate_prompt_engineering_chat_prompt( - self, - node_data: ParameterExtractorNodeData, - query: str, - variable_pool: VariablePool, - model_instance: ModelInstance, - memory: PromptMessageMemory | None, - files: Sequence[File], - vision_detail: ImagePromptMessageContent.DETAIL | None = None, - ) -> list[PromptMessage]: - """ - Generate chat prompt. 
- """ - prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True) - rest_token = self._calculate_rest_token( - node_data=node_data, - query=query, - variable_pool=variable_pool, - model_instance=model_instance, - context="", - ) - prompt_template = self._get_prompt_engineering_prompt_template( - node_data=node_data, - query=CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE.format( - structure=json.dumps(node_data.get_parameter_json_schema()), text=query - ), - variable_pool=variable_pool, - memory=memory, - max_token_limit=rest_token, - ) - - prompt_messages = prompt_transform.get_prompt( - prompt_template=prompt_template, - inputs={}, - query="", - files=files, - context="", - memory_config=node_data.memory, - memory=None, - model_instance=model_instance, - image_detail_config=vision_detail, - ) - - # find last user message - last_user_message_idx = -1 - for i, prompt_message in enumerate(prompt_messages): - if prompt_message.role == PromptMessageRole.USER: - last_user_message_idx = i - - # add example messages before last user message - example_messages = [] - for example in CHAT_EXAMPLE: - example_messages.extend( - [ - UserPromptMessage( - content=CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE.format( - structure=json.dumps(example["user"]["json"]), - text=example["user"]["query"], - ) - ), - AssistantPromptMessage( - content=json.dumps(example["assistant"]["json"]), - ), - ] - ) - - prompt_messages = ( - prompt_messages[:last_user_message_idx] + example_messages + prompt_messages[last_user_message_idx:] - ) - - return prompt_messages - - def _validate_result(self, data: ParameterExtractorNodeData, result: dict): - if len(data.parameters) != len(result): - raise InvalidNumberOfParametersError("Invalid number of parameters") - - for parameter in data.parameters: - if parameter.required and parameter.name not in result: - raise RequiredParameterMissingError(f"Parameter {parameter.name} is required") - - param_value = result.get(parameter.name) - if not 
parameter.type.is_valid(param_value, array_validation=ArrayValidation.ALL): - inferred_type = SegmentType.infer_segment_type(param_value) - raise InvalidValueTypeError( - parameter_name=parameter.name, - expected_type=parameter.type, - actual_type=inferred_type, - value=param_value, - ) - if parameter.type == SegmentType.STRING and parameter.options: - if param_value not in parameter.options: - raise InvalidSelectValueError(f"Invalid `select` value for parameter {parameter.name}") - return result - - @staticmethod - def _transform_number(value: int | float | str | bool) -> int | float | None: - """ - Attempts to transform the input into an integer or float. - - Returns: - int or float: The transformed number if the conversion is successful. - None: If the transformation fails. - - Note: - Boolean values `True` and `False` are converted to integers `1` and `0`, respectively. - This behavior ensures compatibility with existing workflows that may use boolean types as integers. - """ - if isinstance(value, bool): - return int(value) - elif isinstance(value, (int, float)): - return value - elif isinstance(value, str): - if "." in value: - try: - return float(value) - except ValueError: - return None - else: - try: - return int(value) - except ValueError: - return None - else: - return None - - def _transform_result(self, data: ParameterExtractorNodeData, result: dict): - """ - Transform result into standard format. 
- """ - transformed_result: dict[str, Any] = {} - for parameter in data.parameters: - if parameter.name in result: - param_value = result[parameter.name] - # transform value - if parameter.type == SegmentType.NUMBER: - transformed = self._transform_number(param_value) - if transformed is not None: - transformed_result[parameter.name] = transformed - elif parameter.type == SegmentType.BOOLEAN: - if isinstance(result[parameter.name], (bool, int)): - transformed_result[parameter.name] = bool(result[parameter.name]) - # elif isinstance(result[parameter.name], str): - # if result[parameter.name].lower() in ["true", "false"]: - # transformed_result[parameter.name] = bool(result[parameter.name].lower() == "true") - elif parameter.type == SegmentType.STRING: - if isinstance(param_value, str): - transformed_result[parameter.name] = param_value - elif parameter.is_array_type(): - if isinstance(param_value, list): - nested_type = parameter.element_type() - assert nested_type is not None - segment_value = build_segment_with_type(segment_type=SegmentType(parameter.type), value=[]) - transformed_result[parameter.name] = segment_value - for item in param_value: - if nested_type == SegmentType.NUMBER: - transformed = self._transform_number(item) - if transformed is not None: - segment_value.value.append(transformed) - elif nested_type == SegmentType.STRING: - if isinstance(item, str): - segment_value.value.append(item) - elif nested_type == SegmentType.OBJECT: - if isinstance(item, dict): - segment_value.value.append(item) - elif nested_type == SegmentType.BOOLEAN: - if isinstance(item, bool): - segment_value.value.append(item) - - if parameter.name not in transformed_result: - if parameter.type.is_array_type(): - transformed_result[parameter.name] = build_segment_with_type( - segment_type=SegmentType(parameter.type), value=[] - ) - elif parameter.type in (SegmentType.STRING, SegmentType.SECRET): - transformed_result[parameter.name] = "" - elif parameter.type == 
SegmentType.NUMBER: - transformed_result[parameter.name] = 0 - elif parameter.type == SegmentType.BOOLEAN: - transformed_result[parameter.name] = False - else: - raise AssertionError("this statement should be unreachable.") - - return transformed_result - - def _extract_complete_json_response(self, result: str) -> dict | None: - """ - Extract complete json response. - """ - - # extract json from the text - for idx in range(len(result)): - if result[idx] == "{" or result[idx] == "[": - json_str = extract_json(result[idx:]) - if json_str: - with contextlib.suppress(Exception): - return cast(dict, json.loads(json_str)) - logger.info("extra error: %s", result) - return None - - def _extract_json_from_tool_call(self, tool_call: AssistantPromptMessage.ToolCall) -> dict | None: - """ - Extract json from tool call. - """ - if not tool_call or not tool_call.function.arguments: - return None - - result = tool_call.function.arguments - # extract json from the arguments - for idx in range(len(result)): - if result[idx] == "{" or result[idx] == "[": - json_str = extract_json(result[idx:]) - if json_str: - with contextlib.suppress(Exception): - return cast(dict, json.loads(json_str)) - - logger.info("extra error: %s", result) - return None - - def _generate_default_result(self, data: ParameterExtractorNodeData): - """ - Generate default result. 
- """ - result: dict[str, Any] = {} - for parameter in data.parameters: - if parameter.type == "number": - result[parameter.name] = 0 - elif parameter.type == "boolean": - result[parameter.name] = False - elif parameter.type in {"string", "select"}: - result[parameter.name] = "" - - return result - - def _get_function_calling_prompt_template( - self, - node_data: ParameterExtractorNodeData, - query: str, - variable_pool: VariablePool, - memory: PromptMessageMemory | None, - max_token_limit: int = 2000, - ) -> list[ChatModelMessage]: - model_mode = ModelMode(node_data.model.mode) - input_text = query - memory_str = "" - instruction = variable_pool.convert_template(node_data.instruction or "").text - - if memory and node_data.memory and node_data.memory.window: - memory_str = llm_utils.fetch_memory_text( - memory=memory, max_token_limit=max_token_limit, message_limit=node_data.memory.window.size - ) - if model_mode == ModelMode.CHAT: - system_prompt_messages = ChatModelMessage( - role=PromptMessageRole.SYSTEM, - text=FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT.format(histories=memory_str, instruction=instruction), - ) - user_prompt_message = ChatModelMessage(role=PromptMessageRole.USER, text=input_text) - return [system_prompt_messages, user_prompt_message] - else: - raise InvalidModelModeError(f"Model mode {model_mode} not support.") - - def _get_prompt_engineering_prompt_template( - self, - node_data: ParameterExtractorNodeData, - query: str, - variable_pool: VariablePool, - memory: PromptMessageMemory | None, - max_token_limit: int = 2000, - ): - model_mode = ModelMode(node_data.model.mode) - input_text = query - memory_str = "" - instruction = variable_pool.convert_template(node_data.instruction or "").text - - if memory and node_data.memory and node_data.memory.window: - memory_str = llm_utils.fetch_memory_text( - memory=memory, max_token_limit=max_token_limit, message_limit=node_data.memory.window.size - ) - if model_mode == ModelMode.CHAT: - 
system_prompt_messages = ChatModelMessage( - role=PromptMessageRole.SYSTEM, - text=CHAT_GENERATE_JSON_PROMPT.format(histories=memory_str, instructions=instruction), - ) - user_prompt_message = ChatModelMessage(role=PromptMessageRole.USER, text=input_text) - return [system_prompt_messages, user_prompt_message] - elif model_mode == ModelMode.COMPLETION: - return CompletionModelPromptTemplate( - text=COMPLETION_GENERATE_JSON_PROMPT.format( - histories=memory_str, text=input_text, instruction=instruction - ) - .replace("{γγγ", "") - .replace("}γγγ", "") - ) - else: - raise InvalidModelModeError(f"Model mode {model_mode} not support.") - - def _calculate_rest_token( - self, - node_data: ParameterExtractorNodeData, - query: str, - variable_pool: VariablePool, - model_instance: ModelInstance, - context: str | None, - ) -> int: - try: - model_schema = llm_utils.fetch_model_schema(model_instance=model_instance) - except ValueError as exc: - raise ModelSchemaNotFoundError("Model schema not found") from exc - prompt_transform = AdvancedPromptTransform(with_variable_tmpl=True) - - if set(model_schema.features or []) & {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL}: - prompt_template = self._get_function_calling_prompt_template(node_data, query, variable_pool, None, 2000) - else: - prompt_template = self._get_prompt_engineering_prompt_template(node_data, query, variable_pool, None, 2000) - - prompt_messages = prompt_transform.get_prompt( - prompt_template=prompt_template, - inputs={}, - query="", - files=[], - context=context, - memory_config=node_data.memory, - memory=None, - model_instance=model_instance, - ) - rest_tokens = 2000 - - model_context_tokens = model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) - if model_context_tokens: - model_type_instance = cast(LargeLanguageModel, model_instance.model_type_instance) - curr_message_tokens = ( - model_type_instance.get_num_tokens( - model_instance.model_name, model_instance.credentials, prompt_messages - ) 
- + 1000 - ) # add 1000 to ensure tool call messages - - max_tokens = 0 - for parameter_rule in model_schema.parameter_rules: - if parameter_rule.name == "max_tokens" or ( - parameter_rule.use_template and parameter_rule.use_template == "max_tokens" - ): - max_tokens = ( - model_instance.parameters.get(parameter_rule.name) - or model_instance.parameters.get(parameter_rule.use_template or "") - ) or 0 - - rest_tokens = model_context_tokens - max_tokens - curr_message_tokens - rest_tokens = max(rest_tokens, 0) - - return rest_tokens - - @property - def model_instance(self) -> ModelInstance: - return self._model_instance - - @classmethod - def _extract_variable_selector_to_variable_mapping( - cls, - *, - graph_config: Mapping[str, Any], - node_id: str, - node_data: ParameterExtractorNodeData, - ) -> Mapping[str, Sequence[str]]: - _ = graph_config # Explicitly mark as unused - variable_mapping: dict[str, Sequence[str]] = {"query": node_data.query} - - if node_data.instruction: - selectors = variable_template_parser.extract_selectors_from_template(node_data.instruction) - for selector in selectors: - variable_mapping[selector.variable] = selector.value_selector - - variable_mapping = {node_id + "." + key: value for key, value in variable_mapping.items()} - - return variable_mapping diff --git a/api/dify_graph/nodes/parameter_extractor/prompts.py b/api/dify_graph/nodes/parameter_extractor/prompts.py deleted file mode 100644 index 1b29be4418..0000000000 --- a/api/dify_graph/nodes/parameter_extractor/prompts.py +++ /dev/null @@ -1,184 +0,0 @@ -from typing import Any - -FUNCTION_CALLING_EXTRACTOR_NAME = "extract_parameters" - -FUNCTION_CALLING_EXTRACTOR_SYSTEM_PROMPT = f"""You are a helpful assistant tasked with extracting structured information based on specific criteria provided. Follow the guidelines below to ensure consistency and accuracy. -### Task -Always call the `{FUNCTION_CALLING_EXTRACTOR_NAME}` function with the correct parameters. 
Ensure that the information extraction is contextual and aligns with the provided criteria. -### Memory -Here is the chat history between the human and assistant, provided within tags: - -\x7bhistories\x7d - -### Instructions: -Some additional information is provided below. Always adhere to these instructions as closely as possible: - -\x7binstruction\x7d - -Steps: -1. Review the chat history provided within the tags. -2. Extract the relevant information based on the criteria given, output multiple values if there is multiple relevant information that match the criteria in the given text. -3. Generate a well-formatted output using the defined functions and arguments. -4. Use the `extract_parameter` function to create structured outputs with appropriate parameters. -5. Do not include any XML tags in your output. -### Example -To illustrate, if the task involves extracting a user's name and their request, your function call might look like this: Ensure your output follows a similar structure to examples. -### Final Output -Produce well-formatted function calls in json without XML tags, as shown in the example. -""" # noqa: E501 - -FUNCTION_CALLING_EXTRACTOR_USER_TEMPLATE = f"""extract structured information from context inside XML tags by calling the function {FUNCTION_CALLING_EXTRACTOR_NAME} with the correct parameters with structure inside XML tags. - -\x7bcontent\x7d - - - -\x7bstructure\x7d - -""" # noqa: E501 - -FUNCTION_CALLING_EXTRACTOR_EXAMPLE: list[dict[str, Any]] = [ - { - "user": { - "query": "What is the weather today in SF?", - "function": { - "name": FUNCTION_CALLING_EXTRACTOR_NAME, - "parameters": { - "type": "object", - "properties": { - "location": { - "type": "string", - "description": "The location to get the weather information", - "required": True, - }, - }, - "required": ["location"], - }, - }, - }, - "assistant": { - "text": "I need always call the function with the correct parameters." 
- " in this case, I need to call the function with the location parameter.", - "function_call": {"name": FUNCTION_CALLING_EXTRACTOR_NAME, "parameters": {"location": "San Francisco"}}, - }, - }, - { - "user": { - "query": "I want to eat some apple pie.", - "function": { - "name": FUNCTION_CALLING_EXTRACTOR_NAME, - "parameters": { - "type": "object", - "properties": {"food": {"type": "string", "description": "The food to eat", "required": True}}, - "required": ["food"], - }, - }, - }, - "assistant": { - "text": "I need always call the function with the correct parameters." - " in this case, I need to call the function with the food parameter.", - "function_call": {"name": FUNCTION_CALLING_EXTRACTOR_NAME, "parameters": {"food": "apple pie"}}, - }, - }, -] - -COMPLETION_GENERATE_JSON_PROMPT = """### Instructions: -Some extra information are provided below, I should always follow the instructions as possible as I can. - -{instruction} - - -### Extract parameter Workflow -I need to extract the following information from the input text. The tag specifies the 'type', 'description' and 'required' of the information to be extracted. - -{{ structure }} - - -Step 1: Carefully read the input and understand the structure of the expected output. -Step 2: Extract relevant parameters from the provided text based on the name and description of object. -Step 3: Structure the extracted parameters to JSON object as specified in . -Step 4: Ensure that the JSON object is properly formatted and valid. The output should not contain any XML tags. Only the JSON object should be outputted. - -### Memory -Here are the chat histories between human and assistant, inside XML tags. - -{histories} - - -### Structure -Here is the structure of the expected output, I should always follow the output structure. 
-{{γγγ - 'properties1': 'relevant text extracted from input', - 'properties2': 'relevant text extracted from input', -}}γγγ - -### Input Text -Inside XML tags, there is a text that I should extract parameters and convert to a JSON object. - -{text} - - -### Answer -I should always output a valid JSON object. Output nothing other than the JSON object. -```JSON -""" # noqa: E501 - -CHAT_GENERATE_JSON_PROMPT = """You should always follow the instructions and output a valid JSON object. -The structure of the JSON object you can found in the instructions. - -### Memory -Here are the chat histories between human and assistant, inside XML tags. - -{histories} - - -### Instructions: -Some extra information are provided below, you should always follow the instructions as possible as you can. - -{instructions} - -""" - -CHAT_GENERATE_JSON_USER_MESSAGE_TEMPLATE = """### Structure -Here is the structure of the JSON object, you should always follow the structure. - -{structure} - - -### Text to be converted to JSON -Inside XML tags, there is a text that you should convert to a JSON object. 
- -{text} - -""" - -CHAT_EXAMPLE = [ - { - "user": { - "query": "What is the weather today in SF?", - "json": { - "type": "object", - "properties": { - "location": { - "type": "string", - "description": "The location to get the weather information", - "required": True, - } - }, - "required": ["location"], - }, - }, - "assistant": {"text": "I need to output a valid JSON object.", "json": {"location": "San Francisco"}}, - }, - { - "user": { - "query": "I want to eat some apple pie.", - "json": { - "type": "object", - "properties": {"food": {"type": "string", "description": "The food to eat", "required": True}}, - "required": ["food"], - }, - }, - "assistant": {"text": "I need to output a valid JSON object.", "json": {"food": "apple pie"}}, - }, -] diff --git a/api/dify_graph/nodes/protocols.py b/api/dify_graph/nodes/protocols.py deleted file mode 100644 index 62d3bcdca1..0000000000 --- a/api/dify_graph/nodes/protocols.py +++ /dev/null @@ -1,46 +0,0 @@ -from collections.abc import Generator -from typing import Any, Protocol - -import httpx - -from dify_graph.file import File -from dify_graph.file.models import ToolFile - - -class HttpClientProtocol(Protocol): - @property - def max_retries_exceeded_error(self) -> type[Exception]: ... - - @property - def request_error(self) -> type[Exception]: ... - - def get(self, url: str, max_retries: int = ..., **kwargs: Any) -> httpx.Response: ... - - def head(self, url: str, max_retries: int = ..., **kwargs: Any) -> httpx.Response: ... - - def post(self, url: str, max_retries: int = ..., **kwargs: Any) -> httpx.Response: ... - - def put(self, url: str, max_retries: int = ..., **kwargs: Any) -> httpx.Response: ... - - def delete(self, url: str, max_retries: int = ..., **kwargs: Any) -> httpx.Response: ... - - def patch(self, url: str, max_retries: int = ..., **kwargs: Any) -> httpx.Response: ... - - -class FileManagerProtocol(Protocol): - def download(self, f: File, /) -> bytes: ... 
- - -class ToolFileManagerProtocol(Protocol): - def create_file_by_raw( - self, - *, - user_id: str, - tenant_id: str, - conversation_id: str | None, - file_binary: bytes, - mimetype: str, - filename: str | None = None, - ) -> Any: ... - - def get_file_generator_by_tool_file_id(self, tool_file_id: str) -> tuple[Generator | None, ToolFile | None]: ... diff --git a/api/dify_graph/nodes/question_classifier/__init__.py b/api/dify_graph/nodes/question_classifier/__init__.py deleted file mode 100644 index 4d06b6bea3..0000000000 --- a/api/dify_graph/nodes/question_classifier/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .entities import QuestionClassifierNodeData -from .question_classifier_node import QuestionClassifierNode - -__all__ = ["QuestionClassifierNode", "QuestionClassifierNodeData"] diff --git a/api/dify_graph/nodes/question_classifier/entities.py b/api/dify_graph/nodes/question_classifier/entities.py deleted file mode 100644 index 0c1601d439..0000000000 --- a/api/dify_graph/nodes/question_classifier/entities.py +++ /dev/null @@ -1,30 +0,0 @@ -from pydantic import BaseModel, Field - -from core.prompt.entities.advanced_prompt_entities import MemoryConfig -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, NodeType -from dify_graph.nodes.llm import ModelConfig, VisionConfig - - -class ClassConfig(BaseModel): - id: str - name: str - - -class QuestionClassifierNodeData(BaseNodeData): - type: NodeType = BuiltinNodeTypes.QUESTION_CLASSIFIER - query_variable_selector: list[str] - model: ModelConfig - classes: list[ClassConfig] - instruction: str | None = None - memory: MemoryConfig | None = None - vision: VisionConfig = Field(default_factory=VisionConfig) - - @property - def structured_output_enabled(self) -> bool: - # NOTE(QuantumGhost): Temporary workaround for issue #20725 - # (https://github.com/langgenius/dify/issues/20725). 
- # - # The proper fix would be to make `QuestionClassifierNode` inherit - # from `BaseNode` instead of `LLMNode`. - return False diff --git a/api/dify_graph/nodes/question_classifier/exc.py b/api/dify_graph/nodes/question_classifier/exc.py deleted file mode 100644 index 2c6354e2a7..0000000000 --- a/api/dify_graph/nodes/question_classifier/exc.py +++ /dev/null @@ -1,6 +0,0 @@ -class QuestionClassifierNodeError(ValueError): - """Base class for QuestionClassifierNode errors.""" - - -class InvalidModelTypeError(QuestionClassifierNodeError): - """Raised when the model is not a Large Language Model.""" diff --git a/api/dify_graph/nodes/question_classifier/question_classifier_node.py b/api/dify_graph/nodes/question_classifier/question_classifier_node.py deleted file mode 100644 index 928618fdbc..0000000000 --- a/api/dify_graph/nodes/question_classifier/question_classifier_node.py +++ /dev/null @@ -1,399 +0,0 @@ -import json -import re -from collections.abc import Mapping, Sequence -from typing import TYPE_CHECKING, Any - -from core.model_manager import ModelInstance -from core.prompt.simple_prompt_transform import ModelMode -from core.prompt.utils.prompt_message_util import PromptMessageUtil -from dify_graph.entities import GraphInitParams -from dify_graph.entities.graph_config import NodeConfigDict -from dify_graph.enums import ( - BuiltinNodeTypes, - NodeExecutionType, - WorkflowNodeExecutionMetadataKey, - WorkflowNodeExecutionStatus, -) -from dify_graph.model_runtime.entities import LLMUsage, ModelPropertyKey, PromptMessageRole -from dify_graph.model_runtime.memory import PromptMessageMemory -from dify_graph.model_runtime.utils.encoders import jsonable_encoder -from dify_graph.node_events import ModelInvokeCompletedEvent, NodeRunResult -from dify_graph.nodes.base.entities import VariableSelector -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser -from dify_graph.nodes.llm import ( - LLMNode, 
- LLMNodeChatModelMessage, - LLMNodeCompletionModelPromptTemplate, - llm_utils, -) -from dify_graph.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver -from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer -from dify_graph.nodes.protocols import HttpClientProtocol -from libs.json_in_md_parser import parse_and_check_json_markdown - -from .entities import QuestionClassifierNodeData -from .exc import InvalidModelTypeError -from .template_prompts import ( - QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1, - QUESTION_CLASSIFIER_ASSISTANT_PROMPT_2, - QUESTION_CLASSIFIER_COMPLETION_PROMPT, - QUESTION_CLASSIFIER_SYSTEM_PROMPT, - QUESTION_CLASSIFIER_USER_PROMPT_1, - QUESTION_CLASSIFIER_USER_PROMPT_2, - QUESTION_CLASSIFIER_USER_PROMPT_3, -) - -if TYPE_CHECKING: - from dify_graph.file.models import File - from dify_graph.runtime import GraphRuntimeState - - -class QuestionClassifierNode(Node[QuestionClassifierNodeData]): - node_type = BuiltinNodeTypes.QUESTION_CLASSIFIER - execution_type = NodeExecutionType.BRANCH - - _file_outputs: list["File"] - _llm_file_saver: LLMFileSaver - _credentials_provider: "CredentialsProvider" - _model_factory: "ModelFactory" - _model_instance: ModelInstance - _memory: PromptMessageMemory | None - _template_renderer: TemplateRenderer - - def __init__( - self, - id: str, - config: NodeConfigDict, - graph_init_params: "GraphInitParams", - graph_runtime_state: "GraphRuntimeState", - *, - credentials_provider: "CredentialsProvider", - model_factory: "ModelFactory", - model_instance: ModelInstance, - http_client: HttpClientProtocol, - template_renderer: TemplateRenderer, - memory: PromptMessageMemory | None = None, - llm_file_saver: LLMFileSaver | None = None, - ): - super().__init__( - id=id, - config=config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - # LLM file outputs, used for MultiModal outputs. 
- self._file_outputs = [] - - self._credentials_provider = credentials_provider - self._model_factory = model_factory - self._model_instance = model_instance - self._memory = memory - self._template_renderer = template_renderer - - if llm_file_saver is None: - dify_ctx = self.require_dify_context() - llm_file_saver = FileSaverImpl( - user_id=dify_ctx.user_id, - tenant_id=dify_ctx.tenant_id, - http_client=http_client, - ) - self._llm_file_saver = llm_file_saver - - @classmethod - def version(cls): - return "1" - - def _run(self): - node_data = self.node_data - variable_pool = self.graph_runtime_state.variable_pool - - # extract variables - variable = variable_pool.get(node_data.query_variable_selector) if node_data.query_variable_selector else None - query = variable.value if variable else None - variables = {"query": query} - # fetch model instance - model_instance = self._model_instance - # Resolve variable references in string-typed completion params - model_instance.parameters = llm_utils.resolve_completion_params_variables( - model_instance.parameters, variable_pool - ) - memory = self._memory - # fetch instruction - node_data.instruction = node_data.instruction or "" - node_data.instruction = variable_pool.convert_template(node_data.instruction).text - - files = ( - llm_utils.fetch_files( - variable_pool=variable_pool, - selector=node_data.vision.configs.variable_selector, - ) - if node_data.vision.enabled - else [] - ) - - # fetch prompt messages - rest_token = self._calculate_rest_token( - node_data=node_data, - query=query or "", - model_instance=model_instance, - context="", - ) - prompt_template = self._get_prompt_template( - node_data=node_data, - query=query or "", - memory=memory, - max_token_limit=rest_token, - ) - # Some models (e.g. Gemma, Mistral) force roles alternation (user/assistant/user/assistant...). 
- # If both self._get_prompt_template and self._fetch_prompt_messages append a user prompt, - # two consecutive user prompts will be generated, causing model's error. - # To avoid this, set sys_query to an empty string so that only one user prompt is appended at the end. - prompt_messages, stop = llm_utils.fetch_prompt_messages( - prompt_template=prompt_template, - sys_query="", - memory=memory, - model_instance=model_instance, - stop=model_instance.stop, - sys_files=files, - vision_enabled=node_data.vision.enabled, - vision_detail=node_data.vision.configs.detail, - variable_pool=variable_pool, - jinja2_variables=[], - template_renderer=self._template_renderer, - ) - - result_text = "" - usage = LLMUsage.empty_usage() - finish_reason = None - - try: - # handle invoke result - generator = LLMNode.invoke_llm( - model_instance=model_instance, - prompt_messages=prompt_messages, - stop=stop, - user_id=self.require_dify_context().user_id, - structured_output_enabled=False, - structured_output=None, - file_saver=self._llm_file_saver, - file_outputs=self._file_outputs, - node_id=self._node_id, - node_type=self.node_type, - ) - - for event in generator: - if isinstance(event, ModelInvokeCompletedEvent): - result_text = event.text - usage = event.usage - finish_reason = event.finish_reason - break - - rendered_classes = [ - c.model_copy(update={"name": variable_pool.convert_template(c.name).text}) for c in node_data.classes - ] - - category_name = rendered_classes[0].name - category_id = rendered_classes[0].id - if "" in result_text: - result_text = re.sub(r"]*>[\s\S]*?", "", result_text, flags=re.IGNORECASE) - result_text_json = parse_and_check_json_markdown(result_text, []) - # result_text_json = json.loads(result_text.strip('```JSON\n')) - if "category_name" in result_text_json and "category_id" in result_text_json: - category_id_result = result_text_json["category_id"] - classes = rendered_classes - classes_map = {class_.id: class_.name for class_ in classes} - 
category_ids = [_class.id for _class in classes] - if category_id_result in category_ids: - category_name = classes_map[category_id_result] - category_id = category_id_result - process_data = { - "model_mode": node_data.model.mode, - "prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving( - model_mode=node_data.model.mode, prompt_messages=prompt_messages - ), - "usage": jsonable_encoder(usage), - "finish_reason": finish_reason, - "model_provider": model_instance.provider, - "model_name": model_instance.model_name, - } - outputs = { - "class_name": category_name, - "class_id": category_id, - "usage": jsonable_encoder(usage), - } - - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=variables, - process_data=process_data, - outputs=outputs, - edge_source_handle=category_id, - metadata={ - WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens, - WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price, - WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency, - }, - llm_usage=usage, - ) - except ValueError as e: - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - inputs=variables, - error=str(e), - error_type=type(e).__name__, - metadata={ - WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens, - WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price, - WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency, - }, - llm_usage=usage, - ) - - @property - def model_instance(self) -> ModelInstance: - return self._model_instance - - @classmethod - def _extract_variable_selector_to_variable_mapping( - cls, - *, - graph_config: Mapping[str, Any], - node_id: str, - node_data: QuestionClassifierNodeData, - ) -> Mapping[str, Sequence[str]]: - # graph_config is not used in this node type - variable_mapping = {"query": node_data.query_variable_selector} - variable_selectors: list[VariableSelector] = [] - if node_data.instruction: - variable_template_parser = 
VariableTemplateParser(template=node_data.instruction) - variable_selectors.extend(variable_template_parser.extract_variable_selectors()) - for variable_selector in variable_selectors: - variable_mapping[variable_selector.variable] = list(variable_selector.value_selector) - - variable_mapping = {node_id + "." + key: value for key, value in variable_mapping.items()} - - return variable_mapping - - @classmethod - def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]: - """ - Get default config of node. - :param filters: filter by node config parameters (not used in this implementation). - :return: - """ - # filters parameter is not used in this node type - return {"type": "question-classifier", "config": {"instructions": ""}} - - def _calculate_rest_token( - self, - node_data: QuestionClassifierNodeData, - query: str, - model_instance: ModelInstance, - context: str | None, - ) -> int: - model_schema = llm_utils.fetch_model_schema(model_instance=model_instance) - - prompt_template = self._get_prompt_template(node_data, query, None, 2000) - prompt_messages, _ = llm_utils.fetch_prompt_messages( - prompt_template=prompt_template, - sys_query="", - sys_files=[], - context=context, - memory=None, - model_instance=model_instance, - stop=model_instance.stop, - memory_config=node_data.memory, - vision_enabled=False, - vision_detail=node_data.vision.configs.detail, - variable_pool=self.graph_runtime_state.variable_pool, - jinja2_variables=[], - template_renderer=self._template_renderer, - ) - rest_tokens = 2000 - - model_context_tokens = model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE) - if model_context_tokens: - curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages) - - max_tokens = 0 - for parameter_rule in model_schema.parameter_rules: - if parameter_rule.name == "max_tokens" or ( - parameter_rule.use_template and parameter_rule.use_template == "max_tokens" - ): - max_tokens = ( - 
model_instance.parameters.get(parameter_rule.name) - or model_instance.parameters.get(parameter_rule.use_template or "") - ) or 0 - - rest_tokens = model_context_tokens - max_tokens - curr_message_tokens - rest_tokens = max(rest_tokens, 0) - - return rest_tokens - - def _get_prompt_template( - self, - node_data: QuestionClassifierNodeData, - query: str, - memory: PromptMessageMemory | None, - max_token_limit: int = 2000, - ): - model_mode = ModelMode(node_data.model.mode) - classes = node_data.classes - categories = [] - for class_ in classes: - category = {"category_id": class_.id, "category_name": class_.name} - categories.append(category) - instruction = node_data.instruction or "" - input_text = query - memory_str = "" - if memory: - memory_str = llm_utils.fetch_memory_text( - memory=memory, - max_token_limit=max_token_limit, - message_limit=node_data.memory.window.size if node_data.memory and node_data.memory.window else None, - ) - prompt_messages: list[LLMNodeChatModelMessage] = [] - if model_mode == ModelMode.CHAT: - system_prompt_messages = LLMNodeChatModelMessage( - role=PromptMessageRole.SYSTEM, text=QUESTION_CLASSIFIER_SYSTEM_PROMPT.format(histories=memory_str) - ) - prompt_messages.append(system_prompt_messages) - user_prompt_message_1 = LLMNodeChatModelMessage( - role=PromptMessageRole.USER, text=QUESTION_CLASSIFIER_USER_PROMPT_1 - ) - prompt_messages.append(user_prompt_message_1) - assistant_prompt_message_1 = LLMNodeChatModelMessage( - role=PromptMessageRole.ASSISTANT, text=QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1 - ) - prompt_messages.append(assistant_prompt_message_1) - user_prompt_message_2 = LLMNodeChatModelMessage( - role=PromptMessageRole.USER, text=QUESTION_CLASSIFIER_USER_PROMPT_2 - ) - prompt_messages.append(user_prompt_message_2) - assistant_prompt_message_2 = LLMNodeChatModelMessage( - role=PromptMessageRole.ASSISTANT, text=QUESTION_CLASSIFIER_ASSISTANT_PROMPT_2 - ) - prompt_messages.append(assistant_prompt_message_2) - 
user_prompt_message_3 = LLMNodeChatModelMessage( - role=PromptMessageRole.USER, - text=QUESTION_CLASSIFIER_USER_PROMPT_3.format( - input_text=input_text, - categories=json.dumps(categories, ensure_ascii=False), - classification_instructions=instruction, - ), - ) - prompt_messages.append(user_prompt_message_3) - return prompt_messages - elif model_mode == ModelMode.COMPLETION: - return LLMNodeCompletionModelPromptTemplate( - text=QUESTION_CLASSIFIER_COMPLETION_PROMPT.format( - histories=memory_str, - input_text=input_text, - categories=json.dumps(categories, ensure_ascii=False), - classification_instructions=instruction, - ) - ) - - else: - raise InvalidModelTypeError(f"Model mode {model_mode} not support.") diff --git a/api/dify_graph/nodes/question_classifier/template_prompts.py b/api/dify_graph/nodes/question_classifier/template_prompts.py deleted file mode 100644 index a615c32383..0000000000 --- a/api/dify_graph/nodes/question_classifier/template_prompts.py +++ /dev/null @@ -1,76 +0,0 @@ -QUESTION_CLASSIFIER_SYSTEM_PROMPT = """ -### Job Description', -You are a text classification engine that analyzes text data and assigns categories based on user input or automatically determined categories. -### Task -Your task is to assign one categories ONLY to the input text and only one category may be assigned returned in the output. Additionally, you need to extract the key words from the text that are related to the classification. -### Format -The input text is in the variable input_text. Categories are specified as a category list with two filed category_id and category_name in the variable categories. Classification instructions may be included to improve the classification accuracy. -### Constraint -DO NOT include anything other than the JSON array in your response. -### Memory -Here are the chat histories between human and assistant, inside XML tags. 
- -{histories} - -""" # noqa: E501 - -QUESTION_CLASSIFIER_USER_PROMPT_1 = """ - {"input_text": ["I recently had a great experience with your company. The service was prompt and the staff was very friendly."], - "categories": [{"category_id":"f5660049-284f-41a7-b301-fd24176a711c","category_name":"Customer Service"},{"category_id":"8d007d06-f2c9-4be5-8ff6-cd4381c13c60","category_name":"Satisfaction"},{"category_id":"5fbbbb18-9843-466d-9b8e-b9bfbb9482c8","category_name":"Sales"},{"category_id":"23623c75-7184-4a2e-8226-466c2e4631e4","category_name":"Product"}], - "classification_instructions": ["classify the text based on the feedback provided by customer"]} -""" # noqa: E501 - -QUESTION_CLASSIFIER_ASSISTANT_PROMPT_1 = """ -```json - {"keywords": ["recently", "great experience", "company", "service", "prompt", "staff", "friendly"], - "category_id": "f5660049-284f-41a7-b301-fd24176a711c", - "category_name": "Customer Service"} -``` -""" - -QUESTION_CLASSIFIER_USER_PROMPT_2 = """ - {"input_text": ["bad service, slow to bring the food"], - "categories": [{"category_id":"80fb86a0-4454-4bf5-924c-f253fdd83c02","category_name":"Food Quality"},{"category_id":"f6ff5bc3-aca0-4e4a-8627-e760d0aca78f","category_name":"Experience"},{"category_id":"cc771f63-74e7-4c61-882e-3eda9d8ba5d7","category_name":"Price"}], - "classification_instructions": []} -""" # noqa: E501 - -QUESTION_CLASSIFIER_ASSISTANT_PROMPT_2 = """ -```json - {"keywords": ["bad service", "slow", "food", "tip", "terrible", "waitresses"], - "category_id": "f6ff5bc3-aca0-4e4a-8627-e760d0aca78f", - "category_name": "Experience"} -``` -""" - -QUESTION_CLASSIFIER_USER_PROMPT_3 = """ - {{"input_text": ["{input_text}"], - "categories": {categories}, - "classification_instructions": ["{classification_instructions}"]}} -""" - -QUESTION_CLASSIFIER_COMPLETION_PROMPT = """ -### Job Description -You are a text classification engine that analyzes text data and assigns categories based on user input or automatically determined 
categories. -### Task -Your task is to assign one categories ONLY to the input text and only one category may be assigned returned in the output. Additionally, you need to extract the key words from the text that are related to the classification. -### Format -The input text is in the variable input_text. Categories are specified as a category list with two filed category_id and category_name in the variable categories. Classification instructions may be included to improve the classification accuracy. -### Constraint -DO NOT include anything other than the JSON array in your response. -### Example -Here is the chat example between human and assistant, inside XML tags. - -User:{{"input_text": ["I recently had a great experience with your company. The service was prompt and the staff was very friendly."], "categories": [{{"category_id":"f5660049-284f-41a7-b301-fd24176a711c","category_name":"Customer Service"}},{{"category_id":"8d007d06-f2c9-4be5-8ff6-cd4381c13c60","category_name":"Satisfaction"}},{{"category_id":"5fbbbb18-9843-466d-9b8e-b9bfbb9482c8","category_name":"Sales"}},{{"category_id":"23623c75-7184-4a2e-8226-466c2e4631e4","category_name":"Product"}}], "classification_instructions": ["classify the text based on the feedback provided by customer"]}} -Assistant:{{"keywords": ["recently", "great experience", "company", "service", "prompt", "staff", "friendly"],"category_id": "f5660049-284f-41a7-b301-fd24176a711c","category_name": "Customer Service"}} -User:{{"input_text": ["bad service, slow to bring the food"], "categories": [{{"category_id":"80fb86a0-4454-4bf5-924c-f253fdd83c02","category_name":"Food Quality"}},{{"category_id":"f6ff5bc3-aca0-4e4a-8627-e760d0aca78f","category_name":"Experience"}},{{"category_id":"cc771f63-74e7-4c61-882e-3eda9d8ba5d7","category_name":"Price"}}], "classification_instructions": []}} -Assistant:{{"keywords": ["bad service", "slow", "food", "tip", "terrible", "waitresses"],"category_id": 
"f6ff5bc3-aca0-4e4a-8627-e760d0aca78f","category_name": "Experience"}} - -### Memory -Here are the chat histories between human and assistant, inside XML tags. - -{histories} - -### User Input -{{"input_text" : ["{input_text}"], "categories" : {categories},"classification_instruction" : ["{classification_instructions}"]}} -### Assistant Output -""" # noqa: E501 diff --git a/api/dify_graph/nodes/start/__init__.py b/api/dify_graph/nodes/start/__init__.py deleted file mode 100644 index 5411780423..0000000000 --- a/api/dify_graph/nodes/start/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .start_node import StartNode - -__all__ = ["StartNode"] diff --git a/api/dify_graph/nodes/start/entities.py b/api/dify_graph/nodes/start/entities.py deleted file mode 100644 index 92ebd1a2ec..0000000000 --- a/api/dify_graph/nodes/start/entities.py +++ /dev/null @@ -1,16 +0,0 @@ -from collections.abc import Sequence - -from pydantic import Field - -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, NodeType -from dify_graph.variables.input_entities import VariableEntity - - -class StartNodeData(BaseNodeData): - """ - Start Node Data - """ - - type: NodeType = BuiltinNodeTypes.START - variables: Sequence[VariableEntity] = Field(default_factory=list) diff --git a/api/dify_graph/nodes/start/start_node.py b/api/dify_graph/nodes/start/start_node.py deleted file mode 100644 index 5e6055ea34..0000000000 --- a/api/dify_graph/nodes/start/start_node.py +++ /dev/null @@ -1,63 +0,0 @@ -from typing import Any - -from jsonschema import Draft7Validator, ValidationError - -from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID -from dify_graph.enums import BuiltinNodeTypes, NodeExecutionType, WorkflowNodeExecutionStatus -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.start.entities import StartNodeData -from dify_graph.variables.input_entities import VariableEntityType - - 
-class StartNode(Node[StartNodeData]): - node_type = BuiltinNodeTypes.START - execution_type = NodeExecutionType.ROOT - - @classmethod - def version(cls) -> str: - return "1" - - def _run(self) -> NodeRunResult: - node_inputs = dict(self.graph_runtime_state.variable_pool.user_inputs) - self._validate_and_normalize_json_object_inputs(node_inputs) - system_inputs = self.graph_runtime_state.variable_pool.system_variables.to_dict() - - # TODO: System variables should be directly accessible, no need for special handling - # Set system variables as node outputs. - for var in system_inputs: - node_inputs[SYSTEM_VARIABLE_NODE_ID + "." + var] = system_inputs[var] - outputs = dict(node_inputs) - - return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=node_inputs, outputs=outputs) - - def _validate_and_normalize_json_object_inputs(self, node_inputs: dict[str, Any]) -> None: - for variable in self.node_data.variables: - if variable.type != VariableEntityType.JSON_OBJECT: - continue - - key = variable.variable - value = node_inputs.get(key) - - if value is None and variable.required: - raise ValueError(f"{key} is required in input form") - - # If no value provided, skip further processing for this key - if not value: - continue - - if not isinstance(value, dict): - raise ValueError(f"JSON object for '{key}' must be an object") - - # Overwrite with normalized dict to ensure downstream consistency - node_inputs[key] = value - - # If schema exists, then validate against it - schema = variable.json_schema - if not schema: - continue - - try: - Draft7Validator(schema).validate(value) - except ValidationError as e: - raise ValueError(f"JSON object for '{key}' does not match schema: {e.message}") diff --git a/api/dify_graph/nodes/template_transform/__init__.py b/api/dify_graph/nodes/template_transform/__init__.py deleted file mode 100644 index 43863b9d59..0000000000 --- a/api/dify_graph/nodes/template_transform/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from 
.template_transform_node import TemplateTransformNode - -__all__ = ["TemplateTransformNode"] diff --git a/api/dify_graph/nodes/template_transform/entities.py b/api/dify_graph/nodes/template_transform/entities.py deleted file mode 100644 index ac29239958..0000000000 --- a/api/dify_graph/nodes/template_transform/entities.py +++ /dev/null @@ -1,13 +0,0 @@ -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, NodeType -from dify_graph.nodes.base.entities import VariableSelector - - -class TemplateTransformNodeData(BaseNodeData): - """ - Template Transform Node Data. - """ - - type: NodeType = BuiltinNodeTypes.TEMPLATE_TRANSFORM - variables: list[VariableSelector] - template: str diff --git a/api/dify_graph/nodes/template_transform/template_renderer.py b/api/dify_graph/nodes/template_transform/template_renderer.py deleted file mode 100644 index 9b679d4497..0000000000 --- a/api/dify_graph/nodes/template_transform/template_renderer.py +++ /dev/null @@ -1,41 +0,0 @@ -from __future__ import annotations - -from collections.abc import Mapping -from typing import Any, Protocol - -from dify_graph.nodes.code.code_node import WorkflowCodeExecutor -from dify_graph.nodes.code.entities import CodeLanguage - - -class TemplateRenderError(ValueError): - """Raised when rendering a Jinja2 template fails.""" - - -class Jinja2TemplateRenderer(Protocol): - """Render Jinja2 templates for template transform nodes.""" - - def render_template(self, template: str, variables: Mapping[str, Any]) -> str: - """Render a Jinja2 template with provided variables.""" - raise NotImplementedError - - -class CodeExecutorJinja2TemplateRenderer(Jinja2TemplateRenderer): - """Adapter that renders Jinja2 templates via CodeExecutor.""" - - _code_executor: WorkflowCodeExecutor - - def __init__(self, code_executor: WorkflowCodeExecutor) -> None: - self._code_executor = code_executor - - def render_template(self, template: str, variables: Mapping[str, Any]) -> 
str: - try: - result = self._code_executor.execute(language=CodeLanguage.JINJA2, code=template, inputs=variables) - except Exception as exc: - if self._code_executor.is_execution_error(exc): - raise TemplateRenderError(str(exc)) from exc - raise - - rendered = result.get("result") - if not isinstance(rendered, str): - raise TemplateRenderError("Template render result must be a string.") - return rendered diff --git a/api/dify_graph/nodes/template_transform/template_transform_node.py b/api/dify_graph/nodes/template_transform/template_transform_node.py deleted file mode 100644 index dc6fce2b0a..0000000000 --- a/api/dify_graph/nodes/template_transform/template_transform_node.py +++ /dev/null @@ -1,95 +0,0 @@ -from collections.abc import Mapping, Sequence -from typing import TYPE_CHECKING, Any - -from dify_graph.entities.graph_config import NodeConfigDict -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.template_transform.entities import TemplateTransformNodeData -from dify_graph.nodes.template_transform.template_renderer import ( - Jinja2TemplateRenderer, - TemplateRenderError, -) - -if TYPE_CHECKING: - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState - -DEFAULT_TEMPLATE_TRANSFORM_MAX_OUTPUT_LENGTH = 400_000 - - -class TemplateTransformNode(Node[TemplateTransformNodeData]): - node_type = BuiltinNodeTypes.TEMPLATE_TRANSFORM - _template_renderer: Jinja2TemplateRenderer - _max_output_length: int - - def __init__( - self, - id: str, - config: NodeConfigDict, - graph_init_params: "GraphInitParams", - graph_runtime_state: "GraphRuntimeState", - *, - template_renderer: Jinja2TemplateRenderer, - max_output_length: int | None = None, - ) -> None: - super().__init__( - id=id, - config=config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - 
self._template_renderer = template_renderer - - if max_output_length is not None and max_output_length <= 0: - raise ValueError("max_output_length must be a positive integer") - self._max_output_length = max_output_length or DEFAULT_TEMPLATE_TRANSFORM_MAX_OUTPUT_LENGTH - - @classmethod - def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]: - """ - Get default config of node. - :param filters: filter by node config parameters. - :return: - """ - return { - "type": "template-transform", - "config": {"variables": [{"variable": "arg1", "value_selector": []}], "template": "{{ arg1 }}"}, - } - - @classmethod - def version(cls) -> str: - return "1" - - def _run(self) -> NodeRunResult: - # Get variables - variables: dict[str, Any] = {} - for variable_selector in self.node_data.variables: - variable_name = variable_selector.variable - value = self.graph_runtime_state.variable_pool.get(variable_selector.value_selector) - variables[variable_name] = value.to_object() if value else None - # Run code - try: - rendered = self._template_renderer.render_template(self.node_data.template, variables) - except TemplateRenderError as e: - return NodeRunResult(inputs=variables, status=WorkflowNodeExecutionStatus.FAILED, error=str(e)) - - if len(rendered) > self._max_output_length: - return NodeRunResult( - inputs=variables, - status=WorkflowNodeExecutionStatus.FAILED, - error=f"Output length exceeds {self._max_output_length} characters", - ) - - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs=variables, outputs={"output": rendered} - ) - - @classmethod - def _extract_variable_selector_to_variable_mapping( - cls, *, graph_config: Mapping[str, Any], node_id: str, node_data: TemplateTransformNodeData - ) -> Mapping[str, Sequence[str]]: - return { - node_id + "." 
+ variable_selector.variable: variable_selector.value_selector - for variable_selector in node_data.variables - } diff --git a/api/dify_graph/nodes/tool/__init__.py b/api/dify_graph/nodes/tool/__init__.py deleted file mode 100644 index f4982e655d..0000000000 --- a/api/dify_graph/nodes/tool/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .tool_node import ToolNode - -__all__ = ["ToolNode"] diff --git a/api/dify_graph/nodes/tool/entities.py b/api/dify_graph/nodes/tool/entities.py deleted file mode 100644 index b041ee66fd..0000000000 --- a/api/dify_graph/nodes/tool/entities.py +++ /dev/null @@ -1,87 +0,0 @@ -from typing import Any, Literal, Union - -from pydantic import BaseModel, field_validator -from pydantic_core.core_schema import ValidationInfo - -from core.tools.entities.tool_entities import ToolProviderType -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, NodeType - - -class ToolEntity(BaseModel): - provider_id: str - provider_type: ToolProviderType - provider_name: str # redundancy - tool_name: str - tool_label: str # redundancy - tool_configurations: dict[str, Any] - credential_id: str | None = None - plugin_unique_identifier: str | None = None # redundancy - - @field_validator("tool_configurations", mode="before") - @classmethod - def validate_tool_configurations(cls, value, values: ValidationInfo): - if not isinstance(value, dict): - raise ValueError("tool_configurations must be a dictionary") - - for key in values.data.get("tool_configurations", {}): - value = values.data.get("tool_configurations", {}).get(key) - if not isinstance(value, str | int | float | bool): - raise ValueError(f"{key} must be a string") - - return value - - -class ToolNodeData(BaseNodeData, ToolEntity): - type: NodeType = BuiltinNodeTypes.TOOL - - class ToolInput(BaseModel): - # TODO: check this type - value: Union[Any, list[str]] - type: Literal["mixed", "variable", "constant"] - - @field_validator("type", mode="before") - 
@classmethod - def check_type(cls, value, validation_info: ValidationInfo): - typ = value - value = validation_info.data.get("value") - - if value is None: - return typ - - if typ == "mixed" and not isinstance(value, str): - raise ValueError("value must be a string") - elif typ == "variable": - if not isinstance(value, list): - raise ValueError("value must be a list") - for val in value: - if not isinstance(val, str): - raise ValueError("value must be a list of strings") - elif typ == "constant" and not isinstance(value, (allowed_types := (str, int, float, bool, dict, list))): - raise ValueError(f"value must be one of: {', '.join(t.__name__ for t in allowed_types)}") - return typ - - tool_parameters: dict[str, ToolInput] - # The version of the tool parameter. - # If this value is None, it indicates this is a previous version - # and requires using the legacy parameter parsing rules. - tool_node_version: str | None = None - - @field_validator("tool_parameters", mode="before") - @classmethod - def filter_none_tool_inputs(cls, value): - if not isinstance(value, dict): - return value - - return { - key: tool_input - for key, tool_input in value.items() - if tool_input is not None and cls._has_valid_value(tool_input) - } - - @staticmethod - def _has_valid_value(tool_input): - """Check if the value is valid""" - if isinstance(tool_input, dict): - return tool_input.get("value") is not None - return getattr(tool_input, "value", None) is not None diff --git a/api/dify_graph/nodes/tool/exc.py b/api/dify_graph/nodes/tool/exc.py deleted file mode 100644 index 7212e8bfc0..0000000000 --- a/api/dify_graph/nodes/tool/exc.py +++ /dev/null @@ -1,16 +0,0 @@ -class ToolNodeError(ValueError): - """Base exception for tool node errors.""" - - pass - - -class ToolParameterError(ToolNodeError): - """Exception raised for errors in tool parameters.""" - - pass - - -class ToolFileError(ToolNodeError): - """Exception raised for errors related to tool files.""" - - pass diff --git 
a/api/dify_graph/nodes/tool/tool_node.py b/api/dify_graph/nodes/tool/tool_node.py deleted file mode 100644 index 598f0da92e..0000000000 --- a/api/dify_graph/nodes/tool/tool_node.py +++ /dev/null @@ -1,524 +0,0 @@ -from collections.abc import Generator, Mapping, Sequence -from typing import TYPE_CHECKING, Any - -from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler -from core.tools.__base.tool import Tool -from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter -from core.tools.errors import ToolInvokeError -from core.tools.tool_engine import ToolEngine -from core.tools.utils.message_transformer import ToolFileMessageTransformer -from dify_graph.entities.graph_config import NodeConfigDict -from dify_graph.enums import ( - BuiltinNodeTypes, - SystemVariableKey, - WorkflowNodeExecutionMetadataKey, - WorkflowNodeExecutionStatus, -) -from dify_graph.file import File, FileTransferMethod -from dify_graph.model_runtime.entities.llm_entities import LLMUsage -from dify_graph.node_events import NodeEventBase, NodeRunResult, StreamChunkEvent, StreamCompletedEvent -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser -from dify_graph.nodes.protocols import ToolFileManagerProtocol -from dify_graph.variables.segments import ArrayAnySegment, ArrayFileSegment -from dify_graph.variables.variables import ArrayAnyVariable -from factories import file_factory -from services.tools.builtin_tools_manage_service import BuiltinToolManageService - -from .entities import ToolNodeData -from .exc import ( - ToolFileError, - ToolNodeError, - ToolParameterError, -) - -if TYPE_CHECKING: - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState, VariablePool - - -class ToolNode(Node[ToolNodeData]): - """ - Tool Node - """ - - node_type = BuiltinNodeTypes.TOOL - - def __init__( - self, - id: str, - config: NodeConfigDict, - 
graph_init_params: "GraphInitParams", - graph_runtime_state: "GraphRuntimeState", - *, - tool_file_manager_factory: ToolFileManagerProtocol, - ): - super().__init__( - id=id, - config=config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - self._tool_file_manager_factory = tool_file_manager_factory - - @classmethod - def version(cls) -> str: - return "1" - - def populate_start_event(self, event) -> None: - event.provider_id = self.node_data.provider_id - event.provider_type = self.node_data.provider_type - - def _run(self) -> Generator[NodeEventBase, None, None]: - """ - Run the tool node - """ - from core.plugin.impl.exc import PluginDaemonClientSideError, PluginInvokeError - - dify_ctx = self.require_dify_context() - - # fetch tool icon - tool_info = { - "provider_type": self.node_data.provider_type.value, - "provider_id": self.node_data.provider_id, - "plugin_unique_identifier": self.node_data.plugin_unique_identifier, - } - - # get tool runtime - try: - from core.tools.tool_manager import ToolManager - - # This is an issue that caused problems before. - # Logically, we shouldn't use the node_data.version field for judgment - # But for backward compatibility with historical data - # this version field judgment is still preserved here. 
- variable_pool: VariablePool | None = None - if self.node_data.version != "1" or self.node_data.tool_node_version is not None: - variable_pool = self.graph_runtime_state.variable_pool - tool_runtime = ToolManager.get_workflow_tool_runtime( - dify_ctx.tenant_id, - dify_ctx.app_id, - self._node_id, - self.node_data, - dify_ctx.invoke_from, - variable_pool, - ) - except ToolNodeError as e: - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - inputs={}, - metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info}, - error=f"Failed to get tool runtime: {str(e)}", - error_type=type(e).__name__, - ) - ) - return - - # get parameters - tool_parameters = tool_runtime.get_merged_runtime_parameters() or [] - parameters = self._generate_parameters( - tool_parameters=tool_parameters, - variable_pool=self.graph_runtime_state.variable_pool, - node_data=self.node_data, - ) - parameters_for_log = self._generate_parameters( - tool_parameters=tool_parameters, - variable_pool=self.graph_runtime_state.variable_pool, - node_data=self.node_data, - for_log=True, - ) - # get conversation id - conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID]) - - try: - message_stream = ToolEngine.generic_invoke( - tool=tool_runtime, - tool_parameters=parameters, - user_id=dify_ctx.user_id, - workflow_tool_callback=DifyWorkflowCallbackHandler(), - workflow_call_depth=self.workflow_call_depth, - app_id=dify_ctx.app_id, - conversation_id=conversation_id.text if conversation_id else None, - ) - except ToolNodeError as e: - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - inputs=parameters_for_log, - metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info}, - error=f"Failed to invoke tool: {str(e)}", - error_type=type(e).__name__, - ) - ) - return - - try: - # convert tool messages - _ = yield from self._transform_message( - 
messages=message_stream, - tool_info=tool_info, - parameters_for_log=parameters_for_log, - user_id=dify_ctx.user_id, - tenant_id=dify_ctx.tenant_id, - node_id=self._node_id, - tool_runtime=tool_runtime, - ) - except ToolInvokeError as e: - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - inputs=parameters_for_log, - metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info}, - error=f"Failed to invoke tool {self.node_data.provider_name}: {str(e)}", - error_type=type(e).__name__, - ) - ) - except PluginInvokeError as e: - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - inputs=parameters_for_log, - metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info}, - error=e.to_user_friendly_error(plugin_name=self.node_data.provider_name), - error_type=type(e).__name__, - ) - ) - except PluginDaemonClientSideError as e: - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - inputs=parameters_for_log, - metadata={WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info}, - error=f"Failed to invoke tool, error: {e.description}", - error_type=type(e).__name__, - ) - ) - - def _generate_parameters( - self, - *, - tool_parameters: Sequence[ToolParameter], - variable_pool: "VariablePool", - node_data: ToolNodeData, - for_log: bool = False, - ) -> dict[str, Any]: - """ - Generate parameters based on the given tool parameters, variable pool, and node data. - - Args: - tool_parameters (Sequence[ToolParameter]): The list of tool parameters. - variable_pool (VariablePool): The variable pool containing the variables. - node_data (ToolNodeData): The data associated with the tool node. - - Returns: - Mapping[str, Any]: A dictionary containing the generated parameters. 
- - """ - tool_parameters_dictionary = {parameter.name: parameter for parameter in tool_parameters} - - result: dict[str, Any] = {} - for parameter_name in node_data.tool_parameters: - parameter = tool_parameters_dictionary.get(parameter_name) - if not parameter: - result[parameter_name] = None - continue - tool_input = node_data.tool_parameters[parameter_name] - if tool_input.type == "variable": - variable = variable_pool.get(tool_input.value) - if variable is None: - if parameter.required: - raise ToolParameterError(f"Variable {tool_input.value} does not exist") - continue - parameter_value = variable.value - elif tool_input.type in {"mixed", "constant"}: - segment_group = variable_pool.convert_template(str(tool_input.value)) - parameter_value = segment_group.log if for_log else segment_group.text - else: - raise ToolParameterError(f"Unknown tool input type '{tool_input.type}'") - result[parameter_name] = parameter_value - - return result - - def _fetch_files(self, variable_pool: "VariablePool") -> list[File]: - variable = variable_pool.get(["sys", SystemVariableKey.FILES]) - assert isinstance(variable, ArrayAnyVariable | ArrayAnySegment) - return list(variable.value) if variable else [] - - def _transform_message( - self, - messages: Generator[ToolInvokeMessage, None, None], - tool_info: Mapping[str, Any], - parameters_for_log: dict[str, Any], - user_id: str, - tenant_id: str, - node_id: str, - tool_runtime: Tool, - ) -> Generator[NodeEventBase, None, LLMUsage]: - """ - Convert ToolInvokeMessages into tuple[plain_text, files] - """ - # transform message and handle file storage - from core.plugin.impl.plugin import PluginInstaller - - message_stream = ToolFileMessageTransformer.transform_tool_invoke_messages( - messages=messages, - user_id=user_id, - tenant_id=tenant_id, - conversation_id=None, - ) - - text = "" - files: list[File] = [] - json: list[dict | list] = [] - - variables: dict[str, Any] = {} - - for message in message_stream: - if message.type in { - 
ToolInvokeMessage.MessageType.IMAGE_LINK, - ToolInvokeMessage.MessageType.BINARY_LINK, - ToolInvokeMessage.MessageType.IMAGE, - }: - assert isinstance(message.message, ToolInvokeMessage.TextMessage) - - url = message.message.text - if message.meta: - transfer_method = message.meta.get("transfer_method", FileTransferMethod.TOOL_FILE) - else: - transfer_method = FileTransferMethod.TOOL_FILE - - tool_file_id = str(url).split("/")[-1].split(".")[0] - - _, tool_file = self._tool_file_manager_factory.get_file_generator_by_tool_file_id(tool_file_id) - if not tool_file: - raise ToolFileError(f"tool file {tool_file_id} not found") - - mapping = { - "tool_file_id": tool_file_id, - "type": file_factory.get_file_type_by_mime_type(tool_file.mimetype), - "transfer_method": transfer_method, - "url": url, - } - file = file_factory.build_from_mapping( - mapping=mapping, - tenant_id=tenant_id, - ) - files.append(file) - elif message.type == ToolInvokeMessage.MessageType.BLOB: - # get tool file id - assert isinstance(message.message, ToolInvokeMessage.TextMessage) - assert message.meta - - tool_file_id = message.message.text.split("/")[-1].split(".")[0] - _, tool_file = self._tool_file_manager_factory.get_file_generator_by_tool_file_id(tool_file_id) - if not tool_file: - raise ToolFileError(f"tool file {tool_file_id} not exists") - - mapping = { - "tool_file_id": tool_file_id, - "transfer_method": FileTransferMethod.TOOL_FILE, - } - - files.append( - file_factory.build_from_mapping( - mapping=mapping, - tenant_id=tenant_id, - ) - ) - elif message.type == ToolInvokeMessage.MessageType.TEXT: - assert isinstance(message.message, ToolInvokeMessage.TextMessage) - text += message.message.text - yield StreamChunkEvent( - selector=[node_id, "text"], - chunk=message.message.text, - is_final=False, - ) - elif message.type == ToolInvokeMessage.MessageType.JSON: - assert isinstance(message.message, ToolInvokeMessage.JsonMessage) - # JSON message handling for tool node - if 
message.message.json_object: - json.append(message.message.json_object) - elif message.type == ToolInvokeMessage.MessageType.LINK: - assert isinstance(message.message, ToolInvokeMessage.TextMessage) - - # Check if this LINK message is a file link - file_obj = (message.meta or {}).get("file") - if isinstance(file_obj, File): - files.append(file_obj) - stream_text = f"File: {message.message.text}\n" - else: - stream_text = f"Link: {message.message.text}\n" - - text += stream_text - yield StreamChunkEvent( - selector=[node_id, "text"], - chunk=stream_text, - is_final=False, - ) - elif message.type == ToolInvokeMessage.MessageType.VARIABLE: - assert isinstance(message.message, ToolInvokeMessage.VariableMessage) - variable_name = message.message.variable_name - variable_value = message.message.variable_value - if message.message.stream: - if not isinstance(variable_value, str): - raise ToolNodeError("When 'stream' is True, 'variable_value' must be a string.") - if variable_name not in variables: - variables[variable_name] = "" - variables[variable_name] += variable_value - - yield StreamChunkEvent( - selector=[node_id, variable_name], - chunk=variable_value, - is_final=False, - ) - else: - variables[variable_name] = variable_value - elif message.type == ToolInvokeMessage.MessageType.FILE: - assert message.meta is not None - assert isinstance(message.meta, dict) - # Validate that meta contains a 'file' key - if "file" not in message.meta: - raise ToolNodeError("File message is missing 'file' key in meta") - - # Validate that the file is an instance of File - if not isinstance(message.meta["file"], File): - raise ToolNodeError(f"Expected File object but got {type(message.meta['file']).__name__}") - files.append(message.meta["file"]) - elif message.type == ToolInvokeMessage.MessageType.LOG: - assert isinstance(message.message, ToolInvokeMessage.LogMessage) - if message.message.metadata: - icon = tool_info.get("icon", "") - dict_metadata = dict(message.message.metadata) - 
if dict_metadata.get("provider"): - manager = PluginInstaller() - plugins = manager.list_plugins(tenant_id) - try: - current_plugin = next( - plugin - for plugin in plugins - if f"{plugin.plugin_id}/{plugin.name}" == dict_metadata["provider"] - ) - icon = current_plugin.declaration.icon - except StopIteration: - pass - icon_dark = None - try: - builtin_tool = next( - provider - for provider in BuiltinToolManageService.list_builtin_tools( - user_id, - tenant_id, - ) - if provider.name == dict_metadata["provider"] - ) - icon = builtin_tool.icon - icon_dark = builtin_tool.icon_dark - except StopIteration: - pass - - dict_metadata["icon"] = icon - dict_metadata["icon_dark"] = icon_dark - message.message.metadata = dict_metadata - - # Add agent_logs to outputs['json'] to ensure frontend can access thinking process - json_output: list[dict[str, Any] | list[Any]] = [] - - # Step 2: normalize JSON into {"data": [...]}.change json to list[dict] - if json: - json_output.extend(json) - else: - json_output.append({"data": []}) - - # Send final chunk events for all streamed outputs - # Final chunk for text stream - yield StreamChunkEvent( - selector=[self._node_id, "text"], - chunk="", - is_final=True, - ) - - # Final chunks for any streamed variables - for var_name in variables: - yield StreamChunkEvent( - selector=[self._node_id, var_name], - chunk="", - is_final=True, - ) - - usage = self._extract_tool_usage(tool_runtime) - - metadata: dict[WorkflowNodeExecutionMetadataKey, Any] = { - WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info, - } - if isinstance(usage.total_tokens, int) and usage.total_tokens > 0: - metadata[WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS] = usage.total_tokens - metadata[WorkflowNodeExecutionMetadataKey.TOTAL_PRICE] = usage.total_price - metadata[WorkflowNodeExecutionMetadataKey.CURRENCY] = usage.currency - - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs={"text": text, 
"files": ArrayFileSegment(value=files), "json": json_output, **variables}, - metadata=metadata, - inputs=parameters_for_log, - llm_usage=usage, - ) - ) - - return usage - - @staticmethod - def _extract_tool_usage(tool_runtime: Tool) -> LLMUsage: - # Avoid importing WorkflowTool at module import time; rely on duck typing - # Some runtimes expose `latest_usage`; mocks may synthesize arbitrary attributes. - latest = getattr(tool_runtime, "latest_usage", None) - # Normalize into a concrete LLMUsage. MagicMock returns truthy attribute objects - # for any name, so we must type-check here. - if isinstance(latest, LLMUsage): - return latest - if isinstance(latest, dict): - # Allow dict payloads from external runtimes - return LLMUsage.model_validate(latest) - # Fallback to empty usage when attribute is missing or not a valid payload - return LLMUsage.empty_usage() - - @classmethod - def _extract_variable_selector_to_variable_mapping( - cls, - *, - graph_config: Mapping[str, Any], - node_id: str, - node_data: ToolNodeData, - ) -> Mapping[str, Sequence[str]]: - """ - Extract variable selector to variable mapping - :param graph_config: graph config - :param node_id: node id - :param node_data: node data - :return: - """ - _ = graph_config # Explicitly mark as unused - typed_node_data = node_data - result = {} - for parameter_name in typed_node_data.tool_parameters: - input = typed_node_data.tool_parameters[parameter_name] - match input.type: - case "mixed": - assert isinstance(input.value, str) - selectors = VariableTemplateParser(input.value).extract_variable_selectors() - for selector in selectors: - result[selector.variable] = selector.value_selector - case "variable": - selector_key = ".".join(input.value) - result[f"#{selector_key}#"] = input.value - case "constant": - pass - - result = {node_id + "." 
+ key: value for key, value in result.items()} - - return result - - @property - def retry(self) -> bool: - return self.node_data.retry_config.retry_enabled diff --git a/api/dify_graph/nodes/variable_aggregator/__init__.py b/api/dify_graph/nodes/variable_aggregator/__init__.py deleted file mode 100644 index 0b6bf2a5b6..0000000000 --- a/api/dify_graph/nodes/variable_aggregator/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .variable_aggregator_node import VariableAggregatorNode - -__all__ = ["VariableAggregatorNode"] diff --git a/api/dify_graph/nodes/variable_aggregator/entities.py b/api/dify_graph/nodes/variable_aggregator/entities.py deleted file mode 100644 index 4779ebd9a9..0000000000 --- a/api/dify_graph/nodes/variable_aggregator/entities.py +++ /dev/null @@ -1,35 +0,0 @@ -from pydantic import BaseModel - -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, NodeType -from dify_graph.variables.types import SegmentType - - -class AdvancedSettings(BaseModel): - """ - Advanced setting. - """ - - group_enabled: bool - - class Group(BaseModel): - """ - Group. - """ - - output_type: SegmentType - variables: list[list[str]] - group_name: str - - groups: list[Group] - - -class VariableAggregatorNodeData(BaseNodeData): - """ - Variable Aggregator Node Data. 
- """ - - type: NodeType = BuiltinNodeTypes.VARIABLE_AGGREGATOR - output_type: str - variables: list[list[str]] - advanced_settings: AdvancedSettings | None = None diff --git a/api/dify_graph/nodes/variable_aggregator/variable_aggregator_node.py b/api/dify_graph/nodes/variable_aggregator/variable_aggregator_node.py deleted file mode 100644 index 7d26de6232..0000000000 --- a/api/dify_graph/nodes/variable_aggregator/variable_aggregator_node.py +++ /dev/null @@ -1,40 +0,0 @@ -from collections.abc import Mapping - -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.variable_aggregator.entities import VariableAggregatorNodeData -from dify_graph.variables.segments import Segment - - -class VariableAggregatorNode(Node[VariableAggregatorNodeData]): - node_type = BuiltinNodeTypes.VARIABLE_AGGREGATOR - - @classmethod - def version(cls) -> str: - return "1" - - def _run(self) -> NodeRunResult: - # Get variables - outputs: dict[str, Segment | Mapping[str, Segment]] = {} - inputs = {} - - if not self.node_data.advanced_settings or not self.node_data.advanced_settings.group_enabled: - for selector in self.node_data.variables: - variable = self.graph_runtime_state.variable_pool.get(selector) - if variable is not None: - outputs = {"output": variable} - - inputs = {".".join(selector[1:]): variable.to_object()} - break - else: - for group in self.node_data.advanced_settings.groups: - for selector in group.variables: - variable = self.graph_runtime_state.variable_pool.get(selector) - - if variable is not None: - outputs[group.group_name] = {"output": variable} - inputs[".".join(selector[1:])] = variable.to_object() - break - - return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs=outputs, inputs=inputs) diff --git a/api/dify_graph/nodes/variable_assigner/__init__.py b/api/dify_graph/nodes/variable_assigner/__init__.py 
deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/api/dify_graph/nodes/variable_assigner/common/__init__.py b/api/dify_graph/nodes/variable_assigner/common/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/api/dify_graph/nodes/variable_assigner/common/exc.py b/api/dify_graph/nodes/variable_assigner/common/exc.py deleted file mode 100644 index f8dbedc290..0000000000 --- a/api/dify_graph/nodes/variable_assigner/common/exc.py +++ /dev/null @@ -1,4 +0,0 @@ -class VariableOperatorNodeError(ValueError): - """Base error type, don't use directly.""" - - pass diff --git a/api/dify_graph/nodes/variable_assigner/common/helpers.py b/api/dify_graph/nodes/variable_assigner/common/helpers.py deleted file mode 100644 index f0b22904a9..0000000000 --- a/api/dify_graph/nodes/variable_assigner/common/helpers.py +++ /dev/null @@ -1,55 +0,0 @@ -from collections.abc import Mapping, MutableMapping, Sequence -from typing import Any, TypeVar - -from pydantic import BaseModel - -from dify_graph.variables import Segment -from dify_graph.variables.consts import SELECTORS_LENGTH -from dify_graph.variables.types import SegmentType - -# Use double underscore (`__`) prefix for internal variables -# to minimize risk of collision with user-defined variable names. 
-_UPDATED_VARIABLES_KEY = "__updated_variables" - - -class UpdatedVariable(BaseModel): - name: str - selector: Sequence[str] - value_type: SegmentType - new_value: Any = None - - -_T = TypeVar("_T", bound=MutableMapping[str, Any]) - - -def variable_to_processed_data(selector: Sequence[str], seg: Segment) -> UpdatedVariable: - if len(selector) < SELECTORS_LENGTH: - raise Exception("selector too short") - _, var_name = selector[:2] - return UpdatedVariable( - name=var_name, - selector=list(selector[:2]), - value_type=seg.value_type, - new_value=seg.value, - ) - - -def set_updated_variables(m: _T, updates: Sequence[UpdatedVariable]) -> _T: - m[_UPDATED_VARIABLES_KEY] = updates - return m - - -def get_updated_variables(m: Mapping[str, Any]) -> Sequence[UpdatedVariable] | None: - updated_values = m.get(_UPDATED_VARIABLES_KEY, None) - if updated_values is None: - return None - result = [] - for items in updated_values: - if isinstance(items, UpdatedVariable): - result.append(items) - elif isinstance(items, dict): - items = UpdatedVariable.model_validate(items) - result.append(items) - else: - raise TypeError(f"Invalid updated variable: {items}, type={type(items)}") - return result diff --git a/api/dify_graph/nodes/variable_assigner/v1/__init__.py b/api/dify_graph/nodes/variable_assigner/v1/__init__.py deleted file mode 100644 index 7eb1428e50..0000000000 --- a/api/dify_graph/nodes/variable_assigner/v1/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .node import VariableAssignerNode - -__all__ = ["VariableAssignerNode"] diff --git a/api/dify_graph/nodes/variable_assigner/v1/node.py b/api/dify_graph/nodes/variable_assigner/v1/node.py deleted file mode 100644 index f9b261b191..0000000000 --- a/api/dify_graph/nodes/variable_assigner/v1/node.py +++ /dev/null @@ -1,109 +0,0 @@ -from collections.abc import Mapping, Sequence -from typing import TYPE_CHECKING, Any - -from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID -from dify_graph.entities import GraphInitParams 
-from dify_graph.entities.graph_config import NodeConfigDict -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.variable_assigner.common import helpers as common_helpers -from dify_graph.nodes.variable_assigner.common.exc import VariableOperatorNodeError -from dify_graph.variables import SegmentType, VariableBase - -from .node_data import VariableAssignerData, WriteMode - -if TYPE_CHECKING: - from dify_graph.runtime import GraphRuntimeState - - -class VariableAssignerNode(Node[VariableAssignerData]): - node_type = BuiltinNodeTypes.VARIABLE_ASSIGNER - - def __init__( - self, - id: str, - config: NodeConfigDict, - graph_init_params: "GraphInitParams", - graph_runtime_state: "GraphRuntimeState", - ): - super().__init__( - id=id, - config=config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - - def blocks_variable_output(self, variable_selectors: set[tuple[str, ...]]) -> bool: - """ - Check if this Variable Assigner node blocks the output of specific variables. - - Returns True if this node updates any of the requested conversation variables. 
- """ - assigned_selector = tuple(self.node_data.assigned_variable_selector) - return assigned_selector in variable_selectors - - @classmethod - def version(cls) -> str: - return "1" - - @classmethod - def _extract_variable_selector_to_variable_mapping( - cls, - *, - graph_config: Mapping[str, Any], - node_id: str, - node_data: VariableAssignerData, - ) -> Mapping[str, Sequence[str]]: - mapping = {} - assigned_variable_node_id = node_data.assigned_variable_selector[0] - if assigned_variable_node_id == CONVERSATION_VARIABLE_NODE_ID: - selector_key = ".".join(node_data.assigned_variable_selector) - key = f"{node_id}.#{selector_key}#" - mapping[key] = node_data.assigned_variable_selector - - selector_key = ".".join(node_data.input_variable_selector) - key = f"{node_id}.#{selector_key}#" - mapping[key] = node_data.input_variable_selector - return mapping - - def _run(self) -> NodeRunResult: - assigned_variable_selector = self.node_data.assigned_variable_selector - # Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject - original_variable = self.graph_runtime_state.variable_pool.get(assigned_variable_selector) - if not isinstance(original_variable, VariableBase): - raise VariableOperatorNodeError("assigned variable not found") - - match self.node_data.write_mode: - case WriteMode.OVER_WRITE: - income_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector) - if not income_value: - raise VariableOperatorNodeError("input value not found") - updated_variable = original_variable.model_copy(update={"value": income_value.value}) - - case WriteMode.APPEND: - income_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector) - if not income_value: - raise VariableOperatorNodeError("input value not found") - updated_value = original_variable.value + [income_value.value] - updated_variable = original_variable.model_copy(update={"value": updated_value}) - - case WriteMode.CLEAR: - income_value = 
SegmentType.get_zero_value(original_variable.value_type) - updated_variable = original_variable.model_copy(update={"value": income_value.to_object()}) - - # Over write the variable. - self.graph_runtime_state.variable_pool.add(assigned_variable_selector, updated_variable) - - updated_variables = [common_helpers.variable_to_processed_data(assigned_variable_selector, updated_variable)] - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs={ - "value": income_value.to_object(), - }, - # NOTE(QuantumGhost): although only one variable is updated in `v1.VariableAssignerNode`, - # we still set `output_variables` as a list to ensure the schema of output is - # compatible with `v2.VariableAssignerNode`. - process_data=common_helpers.set_updated_variables({}, updated_variables), - outputs={}, - ) diff --git a/api/dify_graph/nodes/variable_assigner/v1/node_data.py b/api/dify_graph/nodes/variable_assigner/v1/node_data.py deleted file mode 100644 index 57acb29535..0000000000 --- a/api/dify_graph/nodes/variable_assigner/v1/node_data.py +++ /dev/null @@ -1,18 +0,0 @@ -from collections.abc import Sequence -from enum import StrEnum - -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, NodeType - - -class WriteMode(StrEnum): - OVER_WRITE = "over-write" - APPEND = "append" - CLEAR = "clear" - - -class VariableAssignerData(BaseNodeData): - type: NodeType = BuiltinNodeTypes.VARIABLE_ASSIGNER - assigned_variable_selector: Sequence[str] - write_mode: WriteMode - input_variable_selector: Sequence[str] diff --git a/api/dify_graph/nodes/variable_assigner/v2/__init__.py b/api/dify_graph/nodes/variable_assigner/v2/__init__.py deleted file mode 100644 index 7eb1428e50..0000000000 --- a/api/dify_graph/nodes/variable_assigner/v2/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .node import VariableAssignerNode - -__all__ = ["VariableAssignerNode"] diff --git a/api/dify_graph/nodes/variable_assigner/v2/entities.py 
b/api/dify_graph/nodes/variable_assigner/v2/entities.py deleted file mode 100644 index 2b2bbe85de..0000000000 --- a/api/dify_graph/nodes/variable_assigner/v2/entities.py +++ /dev/null @@ -1,28 +0,0 @@ -from collections.abc import Sequence -from typing import Any - -from pydantic import BaseModel, Field - -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, NodeType - -from .enums import InputType, Operation - - -class VariableOperationItem(BaseModel): - variable_selector: Sequence[str] - input_type: InputType - operation: Operation - # NOTE(QuantumGhost): The `value` field serves multiple purposes depending on context: - # - # 1. For CONSTANT input_type: Contains the literal value to be used in the operation. - # 2. For VARIABLE input_type: Initially contains the selector of the source variable. - # 3. During the variable updating procedure: The `value` field is reassigned to hold - # the resolved actual value that will be applied to the target variable. 
- value: Any = None - - -class VariableAssignerNodeData(BaseNodeData): - type: NodeType = BuiltinNodeTypes.VARIABLE_ASSIGNER - version: str = "2" - items: Sequence[VariableOperationItem] = Field(default_factory=list) diff --git a/api/dify_graph/nodes/variable_assigner/v2/enums.py b/api/dify_graph/nodes/variable_assigner/v2/enums.py deleted file mode 100644 index 291b1208d4..0000000000 --- a/api/dify_graph/nodes/variable_assigner/v2/enums.py +++ /dev/null @@ -1,20 +0,0 @@ -from enum import StrEnum - - -class Operation(StrEnum): - OVER_WRITE = "over-write" - CLEAR = "clear" - APPEND = "append" - EXTEND = "extend" - SET = "set" - ADD = "+=" - SUBTRACT = "-=" - MULTIPLY = "*=" - DIVIDE = "/=" - REMOVE_FIRST = "remove-first" - REMOVE_LAST = "remove-last" - - -class InputType(StrEnum): - VARIABLE = "variable" - CONSTANT = "constant" diff --git a/api/dify_graph/nodes/variable_assigner/v2/exc.py b/api/dify_graph/nodes/variable_assigner/v2/exc.py deleted file mode 100644 index c50aab8668..0000000000 --- a/api/dify_graph/nodes/variable_assigner/v2/exc.py +++ /dev/null @@ -1,36 +0,0 @@ -from collections.abc import Sequence -from typing import Any - -from dify_graph.nodes.variable_assigner.common.exc import VariableOperatorNodeError - -from .enums import InputType, Operation - - -class OperationNotSupportedError(VariableOperatorNodeError): - def __init__(self, *, operation: Operation, variable_type: str): - super().__init__(f"Operation {operation} is not supported for type {variable_type}") - - -class InputTypeNotSupportedError(VariableOperatorNodeError): - def __init__(self, *, input_type: InputType, operation: Operation): - super().__init__(f"Input type {input_type} is not supported for operation {operation}") - - -class VariableNotFoundError(VariableOperatorNodeError): - def __init__(self, *, variable_selector: Sequence[str]): - super().__init__(f"Variable {variable_selector} not found") - - -class InvalidInputValueError(VariableOperatorNodeError): - def __init__(self, *, 
value: Any): - super().__init__(f"Invalid input value {value}") - - -class ConversationIDNotFoundError(VariableOperatorNodeError): - def __init__(self): - super().__init__("conversation_id not found") - - -class InvalidDataError(VariableOperatorNodeError): - def __init__(self, message: str): - super().__init__(message) diff --git a/api/dify_graph/nodes/variable_assigner/v2/helpers.py b/api/dify_graph/nodes/variable_assigner/v2/helpers.py deleted file mode 100644 index 38c69cbe3c..0000000000 --- a/api/dify_graph/nodes/variable_assigner/v2/helpers.py +++ /dev/null @@ -1,98 +0,0 @@ -from typing import Any - -from dify_graph.variables import SegmentType - -from .enums import Operation - - -def is_operation_supported(*, variable_type: SegmentType, operation: Operation): - match operation: - case Operation.OVER_WRITE | Operation.CLEAR: - return True - case Operation.SET: - return variable_type in { - SegmentType.OBJECT, - SegmentType.STRING, - SegmentType.NUMBER, - SegmentType.INTEGER, - SegmentType.FLOAT, - SegmentType.BOOLEAN, - } - case Operation.ADD | Operation.SUBTRACT | Operation.MULTIPLY | Operation.DIVIDE: - # Only number variable can be added, subtracted, multiplied or divided - return variable_type in {SegmentType.NUMBER, SegmentType.INTEGER, SegmentType.FLOAT} - case Operation.APPEND | Operation.EXTEND | Operation.REMOVE_FIRST | Operation.REMOVE_LAST: - # Only array variable can be appended or extended - # Only array variable can have elements removed - return variable_type.is_array_type() - - -def is_variable_input_supported(*, operation: Operation): - if operation in {Operation.SET, Operation.ADD, Operation.SUBTRACT, Operation.MULTIPLY, Operation.DIVIDE}: - return False - return True - - -def is_constant_input_supported(*, variable_type: SegmentType, operation: Operation): - match variable_type: - case SegmentType.STRING | SegmentType.OBJECT | SegmentType.BOOLEAN: - return operation in {Operation.OVER_WRITE, Operation.SET} - case SegmentType.NUMBER | 
SegmentType.INTEGER | SegmentType.FLOAT: - return operation in { - Operation.OVER_WRITE, - Operation.SET, - Operation.ADD, - Operation.SUBTRACT, - Operation.MULTIPLY, - Operation.DIVIDE, - } - case _: - return False - - -def is_input_value_valid(*, variable_type: SegmentType, operation: Operation, value: Any): - if operation in {Operation.CLEAR, Operation.REMOVE_FIRST, Operation.REMOVE_LAST}: - return True - match variable_type: - case SegmentType.STRING: - return isinstance(value, str) - - case SegmentType.BOOLEAN: - return isinstance(value, bool) - - case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT: - if not isinstance(value, int | float): - return False - if operation == Operation.DIVIDE and value == 0: - return False - return True - - case SegmentType.OBJECT: - return isinstance(value, dict) - - # Array & Append - case SegmentType.ARRAY_ANY if operation == Operation.APPEND: - return isinstance(value, str | float | int | dict) - case SegmentType.ARRAY_STRING if operation == Operation.APPEND: - return isinstance(value, str) - case SegmentType.ARRAY_NUMBER if operation == Operation.APPEND: - return isinstance(value, int | float) - case SegmentType.ARRAY_OBJECT if operation == Operation.APPEND: - return isinstance(value, dict) - case SegmentType.ARRAY_BOOLEAN if operation == Operation.APPEND: - return isinstance(value, bool) - - # Array & Extend / Overwrite - case SegmentType.ARRAY_ANY if operation in {Operation.EXTEND, Operation.OVER_WRITE}: - return isinstance(value, list) and all(isinstance(item, str | float | int | dict) for item in value) - case SegmentType.ARRAY_STRING if operation in {Operation.EXTEND, Operation.OVER_WRITE}: - return isinstance(value, list) and all(isinstance(item, str) for item in value) - case SegmentType.ARRAY_NUMBER if operation in {Operation.EXTEND, Operation.OVER_WRITE}: - return isinstance(value, list) and all(isinstance(item, int | float) for item in value) - case SegmentType.ARRAY_OBJECT if operation in 
{Operation.EXTEND, Operation.OVER_WRITE}: - return isinstance(value, list) and all(isinstance(item, dict) for item in value) - case SegmentType.ARRAY_BOOLEAN if operation in {Operation.EXTEND, Operation.OVER_WRITE}: - return isinstance(value, list) and all(isinstance(item, bool) for item in value) - - case _: - return False diff --git a/api/dify_graph/nodes/variable_assigner/v2/node.py b/api/dify_graph/nodes/variable_assigner/v2/node.py deleted file mode 100644 index f04a6b3b80..0000000000 --- a/api/dify_graph/nodes/variable_assigner/v2/node.py +++ /dev/null @@ -1,246 +0,0 @@ -import json -from collections.abc import Mapping, MutableMapping, Sequence -from typing import TYPE_CHECKING, Any - -from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID -from dify_graph.entities.graph_config import NodeConfigDict -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.variable_assigner.common import helpers as common_helpers -from dify_graph.nodes.variable_assigner.common.exc import VariableOperatorNodeError -from dify_graph.variables import SegmentType, VariableBase -from dify_graph.variables.consts import SELECTORS_LENGTH - -from . 
import helpers -from .entities import VariableAssignerNodeData, VariableOperationItem -from .enums import InputType, Operation -from .exc import ( - InputTypeNotSupportedError, - InvalidDataError, - InvalidInputValueError, - OperationNotSupportedError, - VariableNotFoundError, -) - -if TYPE_CHECKING: - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState - - -def _target_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_id: str, item: VariableOperationItem): - selector_node_id = item.variable_selector[0] - if selector_node_id != CONVERSATION_VARIABLE_NODE_ID: - return - selector_str = ".".join(item.variable_selector) - key = f"{node_id}.#{selector_str}#" - mapping[key] = item.variable_selector - - -def _source_mapping_from_item(mapping: MutableMapping[str, Sequence[str]], node_id: str, item: VariableOperationItem): - # Keep this in sync with the logic in _run methods... - if item.input_type != InputType.VARIABLE: - return - selector = item.value - if not isinstance(selector, list): - raise InvalidDataError(f"selector is not a list, {node_id=}, {item=}") - if len(selector) < SELECTORS_LENGTH: - raise InvalidDataError(f"selector too short, {node_id=}, {item=}") - selector_str = ".".join(selector) - key = f"{node_id}.#{selector_str}#" - mapping[key] = selector - - -class VariableAssignerNode(Node[VariableAssignerNodeData]): - node_type = BuiltinNodeTypes.VARIABLE_ASSIGNER - - def __init__( - self, - id: str, - config: NodeConfigDict, - graph_init_params: "GraphInitParams", - graph_runtime_state: "GraphRuntimeState", - ): - super().__init__( - id=id, - config=config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - - def blocks_variable_output(self, variable_selectors: set[tuple[str, ...]]) -> bool: - """ - Check if this Variable Assigner node blocks the output of specific variables. - - Returns True if this node updates any of the requested conversation variables. 
- """ - # Check each item in this Variable Assigner node - for item in self.node_data.items: - # Convert the item's variable_selector to tuple for comparison - item_selector_tuple = tuple(item.variable_selector) - - # Check if this item updates any of the requested variables - if item_selector_tuple in variable_selectors: - return True - - return False - - @classmethod - def version(cls) -> str: - return "2" - - @classmethod - def _extract_variable_selector_to_variable_mapping( - cls, - *, - graph_config: Mapping[str, Any], - node_id: str, - node_data: VariableAssignerNodeData, - ) -> Mapping[str, Sequence[str]]: - var_mapping: dict[str, Sequence[str]] = {} - for item in node_data.items: - _target_mapping_from_item(var_mapping, node_id, item) - _source_mapping_from_item(var_mapping, node_id, item) - return var_mapping - - def _run(self) -> NodeRunResult: - inputs = self.node_data.model_dump() - process_data: dict[str, Any] = {} - # NOTE: This node has no outputs - updated_variable_selectors: list[Sequence[str]] = [] - - try: - for item in self.node_data.items: - variable = self.graph_runtime_state.variable_pool.get(item.variable_selector) - - # ==================== Validation Part - - # Check if variable exists - if not isinstance(variable, VariableBase): - raise VariableNotFoundError(variable_selector=item.variable_selector) - - # Check if operation is supported - if not helpers.is_operation_supported(variable_type=variable.value_type, operation=item.operation): - raise OperationNotSupportedError(operation=item.operation, variable_type=variable.value_type) - - # Check if variable input is supported - if item.input_type == InputType.VARIABLE and not helpers.is_variable_input_supported( - operation=item.operation - ): - raise InputTypeNotSupportedError(input_type=InputType.VARIABLE, operation=item.operation) - - # Check if constant input is supported - if item.input_type == InputType.CONSTANT and not helpers.is_constant_input_supported( - 
variable_type=variable.value_type, operation=item.operation - ): - raise InputTypeNotSupportedError(input_type=InputType.CONSTANT, operation=item.operation) - - # Get value from variable pool - if ( - item.input_type == InputType.VARIABLE - and item.operation not in {Operation.CLEAR, Operation.REMOVE_FIRST, Operation.REMOVE_LAST} - and item.value is not None - ): - value = self.graph_runtime_state.variable_pool.get(item.value) - if value is None: - raise VariableNotFoundError(variable_selector=item.value) - # Skip if value is NoneSegment - if value.value_type == SegmentType.NONE: - continue - item.value = value.value - - # If set string / bytes / bytearray to object, try convert string to object. - if ( - item.operation == Operation.SET - and variable.value_type == SegmentType.OBJECT - and isinstance(item.value, str | bytes | bytearray) - ): - try: - item.value = json.loads(item.value) - except json.JSONDecodeError: - raise InvalidInputValueError(value=item.value) - - # Check if input value is valid - if not helpers.is_input_value_valid( - variable_type=variable.value_type, operation=item.operation, value=item.value - ): - raise InvalidInputValueError(value=item.value) - - # ==================== Execution Part - - updated_value = self._handle_item( - variable=variable, - operation=item.operation, - value=item.value, - ) - variable = variable.model_copy(update={"value": updated_value}) - self.graph_runtime_state.variable_pool.add(variable.selector, variable) - updated_variable_selectors.append(variable.selector) - except VariableOperatorNodeError as e: - return NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - inputs=inputs, - process_data=process_data, - error=str(e), - ) - - # The `updated_variable_selectors` is a list contains list[str] which not hashable, - # remove the duplicated items first. 
- updated_variable_selectors = list(set(map(tuple, updated_variable_selectors))) - - for selector in updated_variable_selectors: - variable = self.graph_runtime_state.variable_pool.get(selector) - if not isinstance(variable, VariableBase): - raise VariableNotFoundError(variable_selector=selector) - process_data[variable.name] = variable.value - - updated_variables = [ - common_helpers.variable_to_processed_data(selector, seg) - for selector in updated_variable_selectors - if (seg := self.graph_runtime_state.variable_pool.get(selector)) is not None - ] - - process_data = common_helpers.set_updated_variables(process_data, updated_variables) - return NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - inputs=inputs, - process_data=process_data, - outputs={}, - ) - - def _handle_item( - self, - *, - variable: VariableBase, - operation: Operation, - value: Any, - ): - match operation: - case Operation.OVER_WRITE: - return value - case Operation.CLEAR: - return SegmentType.get_zero_value(variable.value_type).to_object() - case Operation.APPEND: - return variable.value + [value] - case Operation.EXTEND: - return variable.value + value - case Operation.SET: - return value - case Operation.ADD: - return variable.value + value - case Operation.SUBTRACT: - return variable.value - value - case Operation.MULTIPLY: - return variable.value * value - case Operation.DIVIDE: - return variable.value / value - case Operation.REMOVE_FIRST: - # If array is empty, do nothing - if not variable.value: - return variable.value - return variable.value[1:] - case Operation.REMOVE_LAST: - # If array is empty, do nothing - if not variable.value: - return variable.value - return variable.value[:-1] diff --git a/api/dify_graph/repositories/__init__.py b/api/dify_graph/repositories/__init__.py deleted file mode 100644 index ef70eb09cc..0000000000 --- a/api/dify_graph/repositories/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -""" -Repository interfaces for data access. 
- -This package contains repository interfaces that define the contract -for accessing and manipulating data, regardless of the underlying -storage mechanism. -""" - -from dify_graph.repositories.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository - -__all__ = [ - "OrderConfig", - "WorkflowNodeExecutionRepository", -] diff --git a/api/dify_graph/repositories/human_input_form_repository.py b/api/dify_graph/repositories/human_input_form_repository.py deleted file mode 100644 index 88966831cb..0000000000 --- a/api/dify_graph/repositories/human_input_form_repository.py +++ /dev/null @@ -1,152 +0,0 @@ -import abc -import dataclasses -from collections.abc import Mapping, Sequence -from datetime import datetime -from typing import Any, Protocol - -from dify_graph.nodes.human_input.entities import DeliveryChannelConfig, HumanInputNodeData -from dify_graph.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus - - -class HumanInputError(Exception): - pass - - -class FormNotFoundError(HumanInputError): - pass - - -@dataclasses.dataclass -class FormCreateParams: - # app_id is the identifier for the app that the form belongs to. - # It is a string with uuid format. - app_id: str - # None when creating a delivery test form; set for runtime forms. - workflow_execution_id: str | None - - # node_id is the identifier for a specific - # node in the graph. - # - # TODO: for node inside loop / iteration, this would - # cause problems, as a single node may be executed multiple times. - node_id: str - - form_config: HumanInputNodeData - rendered_content: str - # Delivery methods already filtered by runtime context (invoke_from). - delivery_methods: Sequence[DeliveryChannelConfig] - # UI display flag computed by runtime context. - display_in_ui: bool - - # resolved_default_values saves the values for defaults with - # type = VARIABLE. 
- # - # For type = CONSTANT, the value is not stored inside `resolved_default_values` - resolved_default_values: Mapping[str, Any] - form_kind: HumanInputFormKind = HumanInputFormKind.RUNTIME - - # Force creating a console-only recipient for submission in Console. - console_recipient_required: bool = False - console_creator_account_id: str | None = None - # Force creating a backstage recipient for submission in Console. - backstage_recipient_required: bool = False - - -class HumanInputFormEntity(abc.ABC): - @property - @abc.abstractmethod - def id(self) -> str: - """id returns the identifer of the form.""" - pass - - @property - @abc.abstractmethod - def web_app_token(self) -> str | None: - """web_app_token returns the token for submission inside webapp. - - For console/debug execution, this may point to the console submission token - if the form is configured to require console delivery. - """ - - # TODO: what if the users are allowed to add multiple - # webapp delivery? - pass - - @property - @abc.abstractmethod - def recipients(self) -> list["HumanInputFormRecipientEntity"]: ... - - @property - @abc.abstractmethod - def rendered_content(self) -> str: - """Rendered markdown content associated with the form.""" - ... - - @property - @abc.abstractmethod - def selected_action_id(self) -> str | None: - """Identifier of the selected user action if the form has been submitted.""" - ... - - @property - @abc.abstractmethod - def submitted_data(self) -> Mapping[str, Any] | None: - """Submitted form data if available.""" - ... - - @property - @abc.abstractmethod - def submitted(self) -> bool: - """Whether the form has been submitted.""" - ... - - @property - @abc.abstractmethod - def status(self) -> HumanInputFormStatus: - """Current status of the form.""" - ... - - @property - @abc.abstractmethod - def expiration_time(self) -> datetime: - """When the form expires.""" - ... 
- - -class HumanInputFormRecipientEntity(abc.ABC): - @property - @abc.abstractmethod - def id(self) -> str: - """id returns the identifer of this recipient.""" - ... - - @property - @abc.abstractmethod - def token(self) -> str: - """token returns a random string used to submit form""" - ... - - -class HumanInputFormRepository(Protocol): - """ - Repository interface for HumanInputForm. - - This interface defines the contract for accessing and manipulating - HumanInputForm data, regardless of the underlying storage mechanism. - - Note: Domain-specific concepts like multi-tenancy (tenant_id), application context (app_id), - and other implementation details should be handled at the implementation level, not in - the core interface. This keeps the core domain model clean and independent of specific - application domains or deployment scenarios. - """ - - def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None: - """Get the form created for a given human input node in a workflow execution. Returns - `None` if the form has not been created yet.""" - ... - - def create_form(self, params: FormCreateParams) -> HumanInputFormEntity: - """ - Create a human input form from form definition. - """ - ... diff --git a/api/dify_graph/repositories/workflow_execution_repository.py b/api/dify_graph/repositories/workflow_execution_repository.py deleted file mode 100644 index ef83f07649..0000000000 --- a/api/dify_graph/repositories/workflow_execution_repository.py +++ /dev/null @@ -1,30 +0,0 @@ -from typing import Protocol - -from dify_graph.entities import WorkflowExecution - - -class WorkflowExecutionRepository(Protocol): - """ - Repository interface for WorkflowExecution. - - This interface defines the contract for accessing and manipulating - WorkflowExecution data, regardless of the underlying storage mechanism. 
- - Note: Domain-specific concepts like multi-tenancy (tenant_id), application context (app_id), - and other implementation details should be handled at the implementation level, not in - the core interface. This keeps the core domain model clean and independent of specific - application domains or deployment scenarios. - """ - - def save(self, execution: WorkflowExecution): - """ - Save or update a WorkflowExecution instance. - - This method handles both creating new records and updating existing ones. - The implementation should determine whether to create or update based on - the execution's ID or other identifying fields. - - Args: - execution: The WorkflowExecution instance to save or update - """ - ... diff --git a/api/dify_graph/repositories/workflow_node_execution_repository.py b/api/dify_graph/repositories/workflow_node_execution_repository.py deleted file mode 100644 index e6c1c3e497..0000000000 --- a/api/dify_graph/repositories/workflow_node_execution_repository.py +++ /dev/null @@ -1,73 +0,0 @@ -from collections.abc import Sequence -from dataclasses import dataclass -from typing import Literal, Protocol - -from dify_graph.entities import WorkflowNodeExecution - - -@dataclass -class OrderConfig: - """Configuration for ordering NodeExecution instances.""" - - order_by: list[str] - order_direction: Literal["asc", "desc"] | None = None - - -class WorkflowNodeExecutionRepository(Protocol): - """ - Repository interface for NodeExecution. - - This interface defines the contract for accessing and manipulating - NodeExecution data, regardless of the underlying storage mechanism. - - Note: Domain-specific concepts like multi-tenancy (tenant_id), application context (app_id), - and trigger sources (triggered_from) should be handled at the implementation level, not in - the core interface. This keeps the core domain model clean and independent of specific - application domains or deployment scenarios. 
- """ - - def save(self, execution: WorkflowNodeExecution): - """ - Save or update a NodeExecution instance. - - This method saves all data on the `WorkflowNodeExecution` object, except for `inputs`, `process_data`, - and `outputs`. Its primary purpose is to persist the status and various metadata, such as execution time - and execution-related details. - - It's main purpose is to save the status and various metadata (execution time, execution metadata etc.) - - This method handles both creating new records and updating existing ones. - The implementation should determine whether to create or update based on - the execution's ID or other identifying fields. - - Args: - execution: The NodeExecution instance to save or update - """ - ... - - def save_execution_data(self, execution: WorkflowNodeExecution): - """Save or update the inputs, process_data, or outputs associated with a specific - node_execution record. - - If any of the inputs, process_data, or outputs are None, those fields will not be updated. - """ - ... - - def get_by_workflow_run( - self, - workflow_run_id: str, - order_config: OrderConfig | None = None, - ) -> Sequence[WorkflowNodeExecution]: - """ - Retrieve all NodeExecution instances for a specific workflow run. - - Args: - workflow_run_id: The workflow run ID - order_config: Optional configuration for ordering results - order_config.order_by: List of fields to order by (e.g., ["index", "created_at"]) - order_config.order_direction: Direction to order ("asc" or "desc") - - Returns: - A list of NodeExecution instances - """ - ... 
diff --git a/api/dify_graph/runtime/__init__.py b/api/dify_graph/runtime/__init__.py deleted file mode 100644 index adca07e59a..0000000000 --- a/api/dify_graph/runtime/__init__.py +++ /dev/null @@ -1,22 +0,0 @@ -from .graph_runtime_state import ( - ChildEngineBuilderNotConfiguredError, - ChildEngineError, - ChildGraphNotFoundError, - GraphRuntimeState, -) -from .graph_runtime_state_protocol import ReadOnlyGraphRuntimeState, ReadOnlyVariablePool -from .read_only_wrappers import ReadOnlyGraphRuntimeStateWrapper, ReadOnlyVariablePoolWrapper -from .variable_pool import VariablePool, VariableValue - -__all__ = [ - "ChildEngineBuilderNotConfiguredError", - "ChildEngineError", - "ChildGraphNotFoundError", - "GraphRuntimeState", - "ReadOnlyGraphRuntimeState", - "ReadOnlyGraphRuntimeStateWrapper", - "ReadOnlyVariablePool", - "ReadOnlyVariablePoolWrapper", - "VariablePool", - "VariableValue", -] diff --git a/api/dify_graph/runtime/graph_runtime_state.py b/api/dify_graph/runtime/graph_runtime_state.py deleted file mode 100644 index 41acc6db35..0000000000 --- a/api/dify_graph/runtime/graph_runtime_state.py +++ /dev/null @@ -1,683 +0,0 @@ -from __future__ import annotations - -import importlib -import json -from collections.abc import Mapping, Sequence -from copy import deepcopy -from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, ClassVar, Protocol - -from pydantic import BaseModel, Field -from pydantic.json import pydantic_encoder - -from dify_graph.enums import NodeExecutionType, NodeState, NodeType -from dify_graph.model_runtime.entities.llm_entities import LLMUsage -from dify_graph.runtime.variable_pool import VariablePool - -if TYPE_CHECKING: - from dify_graph.entities import GraphInitParams - from dify_graph.entities.pause_reason import PauseReason - - -class ReadyQueueProtocol(Protocol): - """Structural interface required from ready queue implementations.""" - - def put(self, item: str) -> None: - """Enqueue the identifier of a node that is ready 
to run.""" - ... - - def get(self, timeout: float | None = None) -> str: - """Return the next node identifier, blocking until available or timeout expires.""" - ... - - def task_done(self) -> None: - """Signal that the most recently dequeued node has completed processing.""" - ... - - def empty(self) -> bool: - """Return True when the queue contains no pending nodes.""" - ... - - def qsize(self) -> int: - """Approximate the number of pending nodes awaiting execution.""" - ... - - def dumps(self) -> str: - """Serialize the queue contents for persistence.""" - ... - - def loads(self, data: str) -> None: - """Restore the queue contents from a serialized payload.""" - ... - - -class GraphExecutionProtocol(Protocol): - """Structural interface for graph execution aggregate. - - Defines the minimal set of attributes and methods required from a GraphExecution entity - for runtime orchestration and state management. - """ - - workflow_id: str - started: bool - completed: bool - aborted: bool - error: Exception | None - exceptions_count: int - pause_reasons: list[PauseReason] - - def start(self) -> None: - """Transition execution into the running state.""" - ... - - def complete(self) -> None: - """Mark execution as successfully completed.""" - ... - - def abort(self, reason: str) -> None: - """Abort execution in response to an external stop request.""" - ... - - def fail(self, error: Exception) -> None: - """Record an unrecoverable error and end execution.""" - ... - - def dumps(self) -> str: - """Serialize execution state into a JSON payload.""" - ... - - def loads(self, data: str) -> None: - """Restore execution state from a previously serialized payload.""" - ... - - -class ResponseStreamCoordinatorProtocol(Protocol): - """Structural interface for response stream coordinator.""" - - def register(self, response_node_id: str) -> None: - """Register a response node so its outputs can be streamed.""" - ... 
- - def loads(self, data: str) -> None: - """Restore coordinator state from a serialized payload.""" - ... - - def dumps(self) -> str: - """Serialize coordinator state for persistence.""" - ... - - -class NodeProtocol(Protocol): - """Structural interface for graph nodes.""" - - id: str - state: NodeState - execution_type: NodeExecutionType - node_type: ClassVar[NodeType] - - def blocks_variable_output(self, variable_selectors: set[tuple[str, ...]]) -> bool: ... - - -class EdgeProtocol(Protocol): - id: str - state: NodeState - tail: str - head: str - source_handle: str - - -class GraphProtocol(Protocol): - """Structural interface required from graph instances attached to the runtime state.""" - - nodes: Mapping[str, NodeProtocol] - edges: Mapping[str, EdgeProtocol] - root_node: NodeProtocol - - def get_outgoing_edges(self, node_id: str) -> Sequence[EdgeProtocol]: ... - - -class ChildGraphEngineBuilderProtocol(Protocol): - def build_child_engine( - self, - *, - workflow_id: str, - graph_init_params: GraphInitParams, - graph_runtime_state: GraphRuntimeState, - graph_config: Mapping[str, Any], - root_node_id: str, - layers: Sequence[object] = (), - ) -> Any: ... 
- - -class ChildEngineError(ValueError): - """Base error type for child-engine creation failures.""" - - -class ChildEngineBuilderNotConfiguredError(ChildEngineError): - """Raised when child-engine creation is requested without a bound builder.""" - - -class ChildGraphNotFoundError(ChildEngineError): - """Raised when the requested child graph entry point cannot be resolved.""" - - -class _GraphStateSnapshot(BaseModel): - """Serializable graph state snapshot for node/edge states.""" - - nodes: dict[str, NodeState] = Field(default_factory=dict) - edges: dict[str, NodeState] = Field(default_factory=dict) - - -@dataclass(slots=True) -class _GraphRuntimeStateSnapshot: - """Immutable view of a serialized runtime state snapshot.""" - - start_at: float - total_tokens: int - node_run_steps: int - llm_usage: LLMUsage - outputs: dict[str, Any] - variable_pool: VariablePool - has_variable_pool: bool - ready_queue_dump: str | None - graph_execution_dump: str | None - response_coordinator_dump: str | None - paused_nodes: tuple[str, ...] - deferred_nodes: tuple[str, ...] - graph_node_states: dict[str, NodeState] - graph_edge_states: dict[str, NodeState] - - -class GraphRuntimeState: - """Mutable runtime state shared across graph execution components. - - `GraphRuntimeState` encapsulates the runtime state of workflow execution, - including scheduling details, variable values, and timing information. - - Values that are initialized prior to workflow execution and remain constant - throughout the execution should be part of `GraphInitParams` instead. 
- """ - - def __init__( - self, - *, - variable_pool: VariablePool, - start_at: float, - total_tokens: int = 0, - llm_usage: LLMUsage | None = None, - outputs: dict[str, object] | None = None, - node_run_steps: int = 0, - ready_queue: ReadyQueueProtocol | None = None, - graph_execution: GraphExecutionProtocol | None = None, - response_coordinator: ResponseStreamCoordinatorProtocol | None = None, - graph: GraphProtocol | None = None, - ) -> None: - self._variable_pool = variable_pool - self._start_at = start_at - - if total_tokens < 0: - raise ValueError("total_tokens must be non-negative") - self._total_tokens = total_tokens - - self._llm_usage = (llm_usage or LLMUsage.empty_usage()).model_copy() - self._outputs = deepcopy(outputs) if outputs is not None else {} - - if node_run_steps < 0: - raise ValueError("node_run_steps must be non-negative") - self._node_run_steps = node_run_steps - - self._graph: GraphProtocol | None = None - - self._ready_queue = ready_queue - self._graph_execution = graph_execution - self._response_coordinator = response_coordinator - self._pending_response_coordinator_dump: str | None = None - self._pending_graph_execution_workflow_id: str | None = None - self._paused_nodes: set[str] = set() - self._deferred_nodes: set[str] = set() - self._child_engine_builder: ChildGraphEngineBuilderProtocol | None = None - - # Node and edges states needed to be restored into - # graph object. - # - # These two fields are non-None only when resuming from a snapshot. - # Once the graph is attached, these two fields will be set to None. 
- self._pending_graph_node_states: dict[str, NodeState] | None = None - self._pending_graph_edge_states: dict[str, NodeState] | None = None - - if graph is not None: - self.attach_graph(graph) - - # ------------------------------------------------------------------ - # Context binding helpers - # ------------------------------------------------------------------ - def attach_graph(self, graph: GraphProtocol) -> None: - """Attach the materialized graph to the runtime state.""" - if self._graph is not None and self._graph is not graph: - raise ValueError("GraphRuntimeState already attached to a different graph instance") - - self._graph = graph - - if self._response_coordinator is None: - self._response_coordinator = self._build_response_coordinator(graph) - - if self._pending_response_coordinator_dump is not None and self._response_coordinator is not None: - self._response_coordinator.loads(self._pending_response_coordinator_dump) - self._pending_response_coordinator_dump = None - self._apply_pending_graph_state() - - def configure(self, *, graph: GraphProtocol | None = None) -> None: - """Ensure core collaborators are initialized with the provided context.""" - if graph is not None: - self.attach_graph(graph) - - # Ensure collaborators are instantiated - _ = self.ready_queue - _ = self.graph_execution - if self._graph is not None: - _ = self.response_coordinator - - def bind_child_engine_builder(self, builder: ChildGraphEngineBuilderProtocol) -> None: - self._child_engine_builder = builder - - def create_child_engine( - self, - *, - workflow_id: str, - graph_init_params: GraphInitParams, - graph_runtime_state: GraphRuntimeState, - graph_config: Mapping[str, Any], - root_node_id: str, - layers: Sequence[object] = (), - ) -> Any: - if self._child_engine_builder is None: - raise ChildEngineBuilderNotConfiguredError("Child engine builder is not configured.") - - return self._child_engine_builder.build_child_engine( - workflow_id=workflow_id, - 
graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - graph_config=graph_config, - root_node_id=root_node_id, - layers=layers, - ) - - # ------------------------------------------------------------------ - # Primary collaborators - # ------------------------------------------------------------------ - @property - def variable_pool(self) -> VariablePool: - return self._variable_pool - - @property - def ready_queue(self) -> ReadyQueueProtocol: - if self._ready_queue is None: - self._ready_queue = self._build_ready_queue() - return self._ready_queue - - @property - def graph_execution(self) -> GraphExecutionProtocol: - if self._graph_execution is None: - self._graph_execution = self._build_graph_execution() - return self._graph_execution - - @property - def response_coordinator(self) -> ResponseStreamCoordinatorProtocol: - if self._response_coordinator is None: - if self._graph is None: - raise ValueError("Graph must be attached before accessing response coordinator") - self._response_coordinator = self._build_response_coordinator(self._graph) - return self._response_coordinator - - # ------------------------------------------------------------------ - # Scalar state - # ------------------------------------------------------------------ - @property - def start_at(self) -> float: - return self._start_at - - @start_at.setter - def start_at(self, value: float) -> None: - self._start_at = value - - @property - def total_tokens(self) -> int: - return self._total_tokens - - @total_tokens.setter - def total_tokens(self, value: int) -> None: - if value < 0: - raise ValueError("total_tokens must be non-negative") - self._total_tokens = value - - @property - def llm_usage(self) -> LLMUsage: - return self._llm_usage.model_copy() - - @llm_usage.setter - def llm_usage(self, value: LLMUsage) -> None: - self._llm_usage = value.model_copy() - - @property - def outputs(self) -> dict[str, Any]: - return deepcopy(self._outputs) - - @outputs.setter - def 
outputs(self, value: dict[str, Any]) -> None: - self._outputs = deepcopy(value) - - def set_output(self, key: str, value: object) -> None: - self._outputs[key] = deepcopy(value) - - def get_output(self, key: str, default: object = None) -> object: - return deepcopy(self._outputs.get(key, default)) - - def update_outputs(self, updates: dict[str, object]) -> None: - for key, value in updates.items(): - self._outputs[key] = deepcopy(value) - - @property - def node_run_steps(self) -> int: - return self._node_run_steps - - @node_run_steps.setter - def node_run_steps(self, value: int) -> None: - if value < 0: - raise ValueError("node_run_steps must be non-negative") - self._node_run_steps = value - - def increment_node_run_steps(self) -> None: - self._node_run_steps += 1 - - def add_tokens(self, tokens: int) -> None: - if tokens < 0: - raise ValueError("tokens must be non-negative") - self._total_tokens += tokens - - # ------------------------------------------------------------------ - # Serialization - # ------------------------------------------------------------------ - def dumps(self) -> str: - """Serialize runtime state into a JSON string.""" - - snapshot: dict[str, Any] = { - "version": "1.0", - "start_at": self._start_at, - "total_tokens": self._total_tokens, - "node_run_steps": self._node_run_steps, - "llm_usage": self._llm_usage.model_dump(mode="json"), - "outputs": self.outputs, - "variable_pool": self.variable_pool.model_dump(mode="json"), - "ready_queue": self.ready_queue.dumps(), - "graph_execution": self.graph_execution.dumps(), - "paused_nodes": list(self._paused_nodes), - "deferred_nodes": list(self._deferred_nodes), - } - - graph_state = self._snapshot_graph_state() - if graph_state is not None: - snapshot["graph_state"] = graph_state - - if self._response_coordinator is not None and self._graph is not None: - snapshot["response_coordinator"] = self._response_coordinator.dumps() - - return json.dumps(snapshot, default=pydantic_encoder) - - @classmethod 
- def from_snapshot(cls, data: str | Mapping[str, Any]) -> GraphRuntimeState: - """Restore runtime state from a serialized snapshot.""" - - snapshot = cls._parse_snapshot_payload(data) - - state = cls( - variable_pool=snapshot.variable_pool, - start_at=snapshot.start_at, - total_tokens=snapshot.total_tokens, - llm_usage=snapshot.llm_usage, - outputs=snapshot.outputs, - node_run_steps=snapshot.node_run_steps, - ) - state._apply_snapshot(snapshot) - return state - - def loads(self, data: str | Mapping[str, Any]) -> None: - """Restore runtime state from a serialized snapshot (legacy API).""" - - snapshot = self._parse_snapshot_payload(data) - self._apply_snapshot(snapshot) - - def register_paused_node(self, node_id: str) -> None: - """Record a node that should resume when execution is continued.""" - - self._paused_nodes.add(node_id) - - def get_paused_nodes(self) -> list[str]: - """Retrieve the list of paused nodes without mutating internal state.""" - - return list(self._paused_nodes) - - def consume_paused_nodes(self) -> list[str]: - """Retrieve and clear the list of paused nodes awaiting resume.""" - - nodes = list(self._paused_nodes) - self._paused_nodes.clear() - return nodes - - def register_deferred_node(self, node_id: str) -> None: - """Record a node that became ready during pause and should resume later.""" - - self._deferred_nodes.add(node_id) - - def get_deferred_nodes(self) -> list[str]: - """Retrieve deferred nodes without mutating internal state.""" - - return list(self._deferred_nodes) - - def consume_deferred_nodes(self) -> list[str]: - """Retrieve and clear deferred nodes awaiting resume.""" - - nodes = list(self._deferred_nodes) - self._deferred_nodes.clear() - return nodes - - # ------------------------------------------------------------------ - # Builders - # ------------------------------------------------------------------ - def _build_ready_queue(self) -> ReadyQueueProtocol: - # Import lazily to avoid breaching architecture boundaries enforced 
by import-linter. - module = importlib.import_module("dify_graph.graph_engine.ready_queue") - in_memory_cls = module.InMemoryReadyQueue - return in_memory_cls() - - def _build_graph_execution(self) -> GraphExecutionProtocol: - # Lazily import to keep the runtime domain decoupled from graph_engine modules. - module = importlib.import_module("dify_graph.graph_engine.domain.graph_execution") - graph_execution_cls = module.GraphExecution - workflow_id = self._pending_graph_execution_workflow_id or "" - self._pending_graph_execution_workflow_id = None - return graph_execution_cls(workflow_id=workflow_id) # type: ignore[invalid-return-type] - - def _build_response_coordinator(self, graph: GraphProtocol) -> ResponseStreamCoordinatorProtocol: - # Lazily import to keep the runtime domain decoupled from graph_engine modules. - module = importlib.import_module("dify_graph.graph_engine.response_coordinator") - coordinator_cls = module.ResponseStreamCoordinator - return coordinator_cls(variable_pool=self.variable_pool, graph=graph) - - # ------------------------------------------------------------------ - # Snapshot helpers - # ------------------------------------------------------------------ - @classmethod - def _parse_snapshot_payload(cls, data: str | Mapping[str, Any]) -> _GraphRuntimeStateSnapshot: - payload: dict[str, Any] - if isinstance(data, str): - payload = json.loads(data) - else: - payload = dict(data) - - version = payload.get("version") - if version != "1.0": - raise ValueError(f"Unsupported GraphRuntimeState snapshot version: {version}") - - start_at = float(payload.get("start_at", 0.0)) - - total_tokens = int(payload.get("total_tokens", 0)) - if total_tokens < 0: - raise ValueError("total_tokens must be non-negative") - - node_run_steps = int(payload.get("node_run_steps", 0)) - if node_run_steps < 0: - raise ValueError("node_run_steps must be non-negative") - - llm_usage_payload = payload.get("llm_usage", {}) - llm_usage = 
LLMUsage.model_validate(llm_usage_payload) - - outputs_payload = deepcopy(payload.get("outputs", {})) - - variable_pool_payload = payload.get("variable_pool") - has_variable_pool = variable_pool_payload is not None - variable_pool = VariablePool.model_validate(variable_pool_payload) if has_variable_pool else VariablePool() - - ready_queue_payload = payload.get("ready_queue") - graph_execution_payload = payload.get("graph_execution") - response_payload = payload.get("response_coordinator") - paused_nodes_payload = payload.get("paused_nodes", []) - deferred_nodes_payload = payload.get("deferred_nodes", []) - graph_state_payload = payload.get("graph_state", {}) or {} - graph_node_states = _coerce_graph_state_map(graph_state_payload, "nodes") - graph_edge_states = _coerce_graph_state_map(graph_state_payload, "edges") - - return _GraphRuntimeStateSnapshot( - start_at=start_at, - total_tokens=total_tokens, - node_run_steps=node_run_steps, - llm_usage=llm_usage, - outputs=outputs_payload, - variable_pool=variable_pool, - has_variable_pool=has_variable_pool, - ready_queue_dump=ready_queue_payload, - graph_execution_dump=graph_execution_payload, - response_coordinator_dump=response_payload, - paused_nodes=tuple(map(str, paused_nodes_payload)), - deferred_nodes=tuple(map(str, deferred_nodes_payload)), - graph_node_states=graph_node_states, - graph_edge_states=graph_edge_states, - ) - - def _apply_snapshot(self, snapshot: _GraphRuntimeStateSnapshot) -> None: - self._start_at = snapshot.start_at - self._total_tokens = snapshot.total_tokens - self._node_run_steps = snapshot.node_run_steps - self._llm_usage = snapshot.llm_usage.model_copy() - self._outputs = deepcopy(snapshot.outputs) - if snapshot.has_variable_pool or self._variable_pool is None: - self._variable_pool = snapshot.variable_pool - - self._restore_ready_queue(snapshot.ready_queue_dump) - self._restore_graph_execution(snapshot.graph_execution_dump) - 
self._restore_response_coordinator(snapshot.response_coordinator_dump) - self._paused_nodes = set(snapshot.paused_nodes) - self._deferred_nodes = set(snapshot.deferred_nodes) - self._pending_graph_node_states = snapshot.graph_node_states or None - self._pending_graph_edge_states = snapshot.graph_edge_states or None - self._apply_pending_graph_state() - - def _restore_ready_queue(self, payload: str | None) -> None: - if payload is not None: - self._ready_queue = self._build_ready_queue() - self._ready_queue.loads(payload) - else: - self._ready_queue = None - - def _restore_graph_execution(self, payload: str | None) -> None: - self._graph_execution = None - self._pending_graph_execution_workflow_id = None - - if payload is None: - return - - try: - execution_payload = json.loads(payload) - self._pending_graph_execution_workflow_id = execution_payload.get("workflow_id") - except (json.JSONDecodeError, TypeError, AttributeError): - self._pending_graph_execution_workflow_id = None - - self.graph_execution.loads(payload) - - def _restore_response_coordinator(self, payload: str | None) -> None: - if payload is None: - self._pending_response_coordinator_dump = None - self._response_coordinator = None - return - - if self._graph is not None: - self.response_coordinator.loads(payload) - self._pending_response_coordinator_dump = None - return - - self._pending_response_coordinator_dump = payload - self._response_coordinator = None - - def _snapshot_graph_state(self) -> _GraphStateSnapshot: - graph = self._graph - if graph is None: - if self._pending_graph_node_states is None and self._pending_graph_edge_states is None: - return _GraphStateSnapshot() - return _GraphStateSnapshot( - nodes=self._pending_graph_node_states or {}, - edges=self._pending_graph_edge_states or {}, - ) - - nodes = graph.nodes - edges = graph.edges - if not isinstance(nodes, Mapping) or not isinstance(edges, Mapping): - return _GraphStateSnapshot() - - node_states = {} - for node_id, node in 
nodes.items(): - if not isinstance(node_id, str): - continue - node_states[node_id] = node.state - - edge_states = {} - for edge_id, edge in edges.items(): - if not isinstance(edge_id, str): - continue - edge_states[edge_id] = edge.state - - return _GraphStateSnapshot(nodes=node_states, edges=edge_states) - - def _apply_pending_graph_state(self) -> None: - if self._graph is None: - return - if self._pending_graph_node_states: - for node_id, state in self._pending_graph_node_states.items(): - node = self._graph.nodes.get(node_id) - if node is None: - continue - node.state = state - if self._pending_graph_edge_states: - for edge_id, state in self._pending_graph_edge_states.items(): - edge = self._graph.edges.get(edge_id) - if edge is None: - continue - edge.state = state - - self._pending_graph_node_states = None - self._pending_graph_edge_states = None - - -def _coerce_graph_state_map(payload: Any, key: str) -> dict[str, NodeState]: - if not isinstance(payload, Mapping): - return {} - raw_map = payload.get(key, {}) - if not isinstance(raw_map, Mapping): - return {} - result: dict[str, NodeState] = {} - for node_id, raw_state in raw_map.items(): - if not isinstance(node_id, str): - continue - try: - result[node_id] = NodeState(str(raw_state)) - except ValueError: - continue - return result diff --git a/api/dify_graph/runtime/graph_runtime_state_protocol.py b/api/dify_graph/runtime/graph_runtime_state_protocol.py deleted file mode 100644 index 7e55ece3f1..0000000000 --- a/api/dify_graph/runtime/graph_runtime_state_protocol.py +++ /dev/null @@ -1,83 +0,0 @@ -from collections.abc import Mapping, Sequence -from typing import Any, Protocol - -from dify_graph.model_runtime.entities.llm_entities import LLMUsage -from dify_graph.system_variable import SystemVariableReadOnlyView -from dify_graph.variables.segments import Segment - - -class ReadOnlyVariablePool(Protocol): - """Read-only interface for VariablePool.""" - - def get(self, selector: Sequence[str], /) -> Segment | 
None: - """Get a variable value (read-only).""" - ... - - def get_all_by_node(self, node_id: str) -> Mapping[str, object]: - """Get all variables for a node (read-only).""" - ... - - def get_by_prefix(self, prefix: str) -> Mapping[str, object]: - """Get all variables stored under a given node prefix (read-only).""" - ... - - -class ReadOnlyGraphRuntimeState(Protocol): - """ - Read-only view of GraphRuntimeState for layers. - - This protocol defines a read-only interface that prevents layers from - modifying the graph runtime state while still allowing observation. - All methods return defensive copies to ensure immutability. - """ - - @property - def system_variable(self) -> SystemVariableReadOnlyView: ... - - @property - def variable_pool(self) -> ReadOnlyVariablePool: - """Get read-only access to the variable pool.""" - ... - - @property - def start_at(self) -> float: - """Get the start time (read-only).""" - ... - - @property - def total_tokens(self) -> int: - """Get the total tokens count (read-only).""" - ... - - @property - def llm_usage(self) -> LLMUsage: - """Get a copy of LLM usage info (read-only).""" - ... - - @property - def outputs(self) -> dict[str, Any]: - """Get a defensive copy of outputs (read-only).""" - ... - - @property - def node_run_steps(self) -> int: - """Get the node run steps count (read-only).""" - ... - - @property - def ready_queue_size(self) -> int: - """Get the number of nodes currently in the ready queue.""" - ... - - @property - def exceptions_count(self) -> int: - """Get the number of node execution exceptions recorded.""" - ... - - def get_output(self, key: str, default: Any = None) -> Any: - """Get a single output value (returns a copy).""" - ... - - def dumps(self) -> str: - """Serialize the runtime state into a JSON snapshot (read-only).""" - ... 
diff --git a/api/dify_graph/runtime/read_only_wrappers.py b/api/dify_graph/runtime/read_only_wrappers.py deleted file mode 100644 index ca06d88c3d..0000000000 --- a/api/dify_graph/runtime/read_only_wrappers.py +++ /dev/null @@ -1,87 +0,0 @@ -from __future__ import annotations - -from collections.abc import Mapping, Sequence -from copy import deepcopy -from typing import Any - -from dify_graph.model_runtime.entities.llm_entities import LLMUsage -from dify_graph.system_variable import SystemVariableReadOnlyView -from dify_graph.variables.segments import Segment - -from .graph_runtime_state import GraphRuntimeState -from .variable_pool import VariablePool - - -class ReadOnlyVariablePoolWrapper: - """Provide defensive, read-only access to ``VariablePool``.""" - - def __init__(self, variable_pool: VariablePool) -> None: - self._variable_pool = variable_pool - - def get(self, selector: Sequence[str], /) -> Segment | None: - """Return a copy of a variable value if present.""" - value = self._variable_pool.get(selector) - return deepcopy(value) if value is not None else None - - def get_all_by_node(self, node_id: str) -> Mapping[str, object]: - """Return a copy of all variables for the specified node.""" - variables: dict[str, object] = {} - if node_id in self._variable_pool.variable_dictionary: - for key, variable in self._variable_pool.variable_dictionary[node_id].items(): - variables[key] = deepcopy(variable.value) - return variables - - def get_by_prefix(self, prefix: str) -> Mapping[str, object]: - """Return a copy of all variables stored under the given prefix.""" - return self._variable_pool.get_by_prefix(prefix) - - -class ReadOnlyGraphRuntimeStateWrapper: - """Expose a defensive, read-only view of ``GraphRuntimeState``.""" - - def __init__(self, state: GraphRuntimeState) -> None: - self._state = state - self._variable_pool_wrapper = ReadOnlyVariablePoolWrapper(state.variable_pool) - - @property - def system_variable(self) -> SystemVariableReadOnlyView: - return 
self._state.variable_pool.system_variables.as_view() - - @property - def variable_pool(self) -> ReadOnlyVariablePoolWrapper: - return self._variable_pool_wrapper - - @property - def start_at(self) -> float: - return self._state.start_at - - @property - def total_tokens(self) -> int: - return self._state.total_tokens - - @property - def llm_usage(self) -> LLMUsage: - return self._state.llm_usage.model_copy() - - @property - def outputs(self) -> dict[str, Any]: - return deepcopy(self._state.outputs) - - @property - def node_run_steps(self) -> int: - return self._state.node_run_steps - - @property - def ready_queue_size(self) -> int: - return self._state.ready_queue.qsize() - - @property - def exceptions_count(self) -> int: - return self._state.graph_execution.exceptions_count - - def get_output(self, key: str, default: Any = None) -> Any: - return self._state.get_output(key, default) - - def dumps(self) -> str: - """Serialize the underlying runtime state for external persistence.""" - return self._state.dumps() diff --git a/api/dify_graph/runtime/variable_pool.py b/api/dify_graph/runtime/variable_pool.py deleted file mode 100644 index e3ef6a2897..0000000000 --- a/api/dify_graph/runtime/variable_pool.py +++ /dev/null @@ -1,280 +0,0 @@ -from __future__ import annotations - -import re -from collections import defaultdict -from collections.abc import Mapping, Sequence -from copy import deepcopy -from typing import Annotated, Any, Union, cast - -from pydantic import BaseModel, Field - -from dify_graph.constants import ( - CONVERSATION_VARIABLE_NODE_ID, - ENVIRONMENT_VARIABLE_NODE_ID, - RAG_PIPELINE_VARIABLE_NODE_ID, - SYSTEM_VARIABLE_NODE_ID, -) -from dify_graph.file import File, FileAttribute, file_manager -from dify_graph.system_variable import SystemVariable -from dify_graph.variables import Segment, SegmentGroup, VariableBase -from dify_graph.variables.consts import SELECTORS_LENGTH -from dify_graph.variables.segments import FileSegment, ObjectSegment -from 
dify_graph.variables.variables import RAGPipelineVariableInput, Variable -from factories import variable_factory - -VariableValue = Union[str, int, float, dict[str, object], list[object], File] - -VARIABLE_PATTERN = re.compile(r"\{\{#([a-zA-Z0-9_]{1,50}(?:\.[a-zA-Z_][a-zA-Z0-9_]{0,29}){1,10})#\}\}") - - -class VariablePool(BaseModel): - # Variable dictionary is a dictionary for looking up variables by their selector. - # The first element of the selector is the node id, it's the first-level key in the dictionary. - # Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the - # elements of the selector except the first one. - variable_dictionary: defaultdict[str, Annotated[dict[str, Variable], Field(default_factory=dict)]] = Field( - description="Variables mapping", - default=defaultdict(dict), - ) - - # The `user_inputs` is used only when constructing the inputs for the `StartNode`. It's not used elsewhere. - user_inputs: Mapping[str, Any] = Field( - description="User inputs", - default_factory=dict, - ) - system_variables: SystemVariable = Field( - description="System variables", - default_factory=SystemVariable.default, - ) - environment_variables: Sequence[Variable] = Field( - description="Environment variables.", - default_factory=list[Variable], - ) - conversation_variables: Sequence[Variable] = Field( - description="Conversation variables.", - default_factory=list[Variable], - ) - rag_pipeline_variables: list[RAGPipelineVariableInput] = Field( - description="RAG pipeline variables.", - default_factory=list, - ) - - def model_post_init(self, context: Any, /): - # Create a mapping from field names to SystemVariableKey enum values - self._add_system_variables(self.system_variables) - # Add environment variables to the variable pool - for var in self.environment_variables: - self.add((ENVIRONMENT_VARIABLE_NODE_ID, var.name), var) - # Add conversation variables to the variable pool. 
When restoring from a serialized - # snapshot, `variable_dictionary` already carries the latest runtime values. - # In that case, keep existing entries instead of overwriting them with the - # bootstrap list. - for var in self.conversation_variables: - selector = (CONVERSATION_VARIABLE_NODE_ID, var.name) - if self._has(selector): - continue - self.add(selector, var) - # Add rag pipeline variables to the variable pool - if self.rag_pipeline_variables: - rag_pipeline_variables_map: defaultdict[Any, dict[Any, Any]] = defaultdict(dict) - for rag_var in self.rag_pipeline_variables: - node_id = rag_var.variable.belong_to_node_id - key = rag_var.variable.variable - value = rag_var.value - rag_pipeline_variables_map[node_id][key] = value - for key, value in rag_pipeline_variables_map.items(): - self.add((RAG_PIPELINE_VARIABLE_NODE_ID, key), value) - - def add(self, selector: Sequence[str], value: Any, /): - """ - Add a variable to the variable pool. - - This method accepts a selector path and a value, converting the value - to a Variable object if necessary before storing it in the pool. - - Args: - selector: A two-element sequence containing [node_id, variable_name]. - The selector must have exactly 2 elements to be valid. - value: The value to store. Can be a Variable, Segment, or any value - that can be converted to a Segment (str, int, float, dict, list, File). - - Raises: - ValueError: If selector length is not exactly 2 elements. - - Note: - While non-Segment values are currently accepted and automatically - converted, it's recommended to pass Segment or Variable objects directly. 
- """ - if len(selector) != SELECTORS_LENGTH: - raise ValueError( - f"Invalid selector: expected {SELECTORS_LENGTH} elements (node_id, variable_name), " - f"got {len(selector)} elements" - ) - - if isinstance(value, VariableBase): - variable = value - elif isinstance(value, Segment): - variable = variable_factory.segment_to_variable(segment=value, selector=selector) - else: - segment = variable_factory.build_segment(value) - variable = variable_factory.segment_to_variable(segment=segment, selector=selector) - - node_id, name = self._selector_to_keys(selector) - # Based on the definition of `Variable`, - # `VariableBase` instances can be safely used as `Variable` since they are compatible. - self.variable_dictionary[node_id][name] = cast(Variable, variable) - - @classmethod - def _selector_to_keys(cls, selector: Sequence[str]) -> tuple[str, str]: - return selector[0], selector[1] - - def _has(self, selector: Sequence[str]) -> bool: - node_id, name = self._selector_to_keys(selector) - if node_id not in self.variable_dictionary: - return False - if name not in self.variable_dictionary[node_id]: - return False - return True - - def get(self, selector: Sequence[str], /) -> Segment | None: - """ - Retrieve a variable's value from the pool as a Segment. - - This method supports both simple selectors [node_id, variable_name] and - extended selectors that include attribute access for FileSegment and - ObjectSegment types. - - Args: - selector: A sequence with at least 2 elements: - - [node_id, variable_name]: Returns the full segment - - [node_id, variable_name, attr, ...]: Returns a nested value - from FileSegment (e.g., 'url', 'name') or ObjectSegment - - Returns: - The Segment associated with the selector, or None if not found. - Returns None if selector has fewer than 2 elements. - - Raises: - ValueError: If attempting to access an invalid FileAttribute. 
- """ - if len(selector) < SELECTORS_LENGTH: - return None - - node_id, name = self._selector_to_keys(selector) - node_map = self.variable_dictionary.get(node_id) - if node_map is None: - return None - - segment: Segment | None = node_map.get(name) - - if segment is None: - return None - - if len(selector) == 2: - return segment - - if isinstance(segment, FileSegment): - attr = selector[2] - # Python support `attr in FileAttribute` after 3.12 - if attr not in {item.value for item in FileAttribute}: - return None - attr = FileAttribute(attr) - attr_value = file_manager.get_attr(file=segment.value, attr=attr) - return variable_factory.build_segment(attr_value) - - # Navigate through nested attributes - result: Any = segment - for attr in selector[2:]: - result = self._extract_value(result) - result = self._get_nested_attribute(result, attr) - if result is None: - return None - - # Return result as Segment - return result if isinstance(result, Segment) else variable_factory.build_segment(result) - - def _extract_value(self, obj: Any): - """Extract the actual value from an ObjectSegment.""" - return obj.value if isinstance(obj, ObjectSegment) else obj - - def _get_nested_attribute(self, obj: Mapping[str, Any], attr: str) -> Segment | None: - """ - Get a nested attribute from a dictionary-like object. - - Args: - obj: The dictionary-like object to search. - attr: The key to look up. - - Returns: - Segment | None: - The corresponding Segment built from the attribute value if the key exists, - otherwise None. - """ - if not isinstance(obj, dict) or attr not in obj: - return None - return variable_factory.build_segment(obj.get(attr)) - - def remove(self, selector: Sequence[str], /): - """ - Remove variables from the variable pool based on the given selector. - - Args: - selector (Sequence[str]): A sequence of strings representing the selector. 
- - Returns: - None - """ - if not selector: - return - if len(selector) == 1: - self.variable_dictionary[selector[0]] = {} - return - key, hash_key = self._selector_to_keys(selector) - self.variable_dictionary[key].pop(hash_key, None) - - def convert_template(self, template: str, /): - parts = VARIABLE_PATTERN.split(template) - segments: list[Segment] = [] - for part in filter(lambda x: x, parts): - if "." in part and (variable := self.get(part.split("."))): - segments.append(variable) - else: - segments.append(variable_factory.build_segment(part)) - return SegmentGroup(value=segments) - - def get_file(self, selector: Sequence[str], /) -> FileSegment | None: - segment = self.get(selector) - if isinstance(segment, FileSegment): - return segment - return None - - def get_by_prefix(self, prefix: str, /) -> Mapping[str, object]: - """Return a copy of all variables stored under the given node prefix.""" - - nodes = self.variable_dictionary.get(prefix) - if not nodes: - return {} - - result: dict[str, object] = {} - for key, variable in nodes.items(): - value = variable.value - result[key] = deepcopy(value) - - return result - - def _add_system_variables(self, system_variable: SystemVariable): - sys_var_mapping = system_variable.to_dict() - for key, value in sys_var_mapping.items(): - if value is None: - continue - selector = (SYSTEM_VARIABLE_NODE_ID, key) - # If the system variable already exists, do not add it again. - # This ensures that we can keep the id of the system variables intact. 
- if self._has(selector): - continue - self.add(selector, value) - - @classmethod - def empty(cls) -> VariablePool: - """Create an empty variable pool.""" - return cls(system_variables=SystemVariable.default()) diff --git a/api/dify_graph/system_variable.py b/api/dify_graph/system_variable.py deleted file mode 100644 index cc5deda892..0000000000 --- a/api/dify_graph/system_variable.py +++ /dev/null @@ -1,217 +0,0 @@ -from __future__ import annotations - -from collections.abc import Mapping, Sequence -from types import MappingProxyType -from typing import Any -from uuid import uuid4 - -from pydantic import AliasChoices, BaseModel, ConfigDict, Field, model_validator - -from dify_graph.enums import SystemVariableKey -from dify_graph.file.models import File - - -class SystemVariable(BaseModel): - """A model for managing system variables. - - Fields with a value of `None` are treated as absent and will not be included - in the variable pool. - """ - - model_config = ConfigDict( - extra="forbid", - serialize_by_alias=True, - validate_by_alias=True, - ) - - user_id: str | None = None - - # Ideally, `app_id` and `workflow_id` should be required and not `None`. - # However, there are scenarios in the codebase where these fields are not set. - # To maintain compatibility, they are marked as optional here. - app_id: str | None = None - workflow_id: str | None = None - - timestamp: int | None = None - - files: Sequence[File] = Field(default_factory=list) - - # NOTE: The `workflow_execution_id` field was previously named `workflow_run_id`. - # To maintain compatibility with existing workflows, it must be serialized - # as `workflow_run_id` in dictionaries or JSON objects, and also referenced - # as `workflow_run_id` in the variable pool. - workflow_execution_id: str | None = Field( - validation_alias=AliasChoices("workflow_execution_id", "workflow_run_id"), - serialization_alias="workflow_run_id", - default=None, - ) - # Chatflow related fields. 
- query: str | None = None - conversation_id: str | None = None - dialogue_count: int | None = None - document_id: str | None = None - original_document_id: str | None = None - dataset_id: str | None = None - batch: str | None = None - datasource_type: str | None = None - datasource_info: Mapping[str, Any] | None = None - invoke_from: str | None = None - - @model_validator(mode="before") - @classmethod - def validate_json_fields(cls, data): - if isinstance(data, dict): - # For JSON validation, only allow workflow_run_id - if "workflow_execution_id" in data and "workflow_run_id" not in data: - # This is likely from direct instantiation, allow it - return data - elif "workflow_execution_id" in data and "workflow_run_id" in data: - # Both present, remove workflow_execution_id - data = data.copy() - data.pop("workflow_execution_id") - return data - return data - - @classmethod - def default(cls) -> SystemVariable: - return cls(workflow_execution_id=str(uuid4())) - - def to_dict(self) -> dict[SystemVariableKey, Any]: - # NOTE: This method is provided for compatibility with legacy code. - # New code should use the `SystemVariable` object directly instead of converting - # it to a dictionary, as this conversion results in the loss of type information - # for each key, making static analysis more difficult. 
- - d: dict[SystemVariableKey, Any] = { - SystemVariableKey.FILES: self.files, - } - if self.user_id is not None: - d[SystemVariableKey.USER_ID] = self.user_id - if self.app_id is not None: - d[SystemVariableKey.APP_ID] = self.app_id - if self.workflow_id is not None: - d[SystemVariableKey.WORKFLOW_ID] = self.workflow_id - if self.workflow_execution_id is not None: - d[SystemVariableKey.WORKFLOW_EXECUTION_ID] = self.workflow_execution_id - if self.query is not None: - d[SystemVariableKey.QUERY] = self.query - if self.conversation_id is not None: - d[SystemVariableKey.CONVERSATION_ID] = self.conversation_id - if self.dialogue_count is not None: - d[SystemVariableKey.DIALOGUE_COUNT] = self.dialogue_count - if self.document_id is not None: - d[SystemVariableKey.DOCUMENT_ID] = self.document_id - if self.original_document_id is not None: - d[SystemVariableKey.ORIGINAL_DOCUMENT_ID] = self.original_document_id - if self.dataset_id is not None: - d[SystemVariableKey.DATASET_ID] = self.dataset_id - if self.batch is not None: - d[SystemVariableKey.BATCH] = self.batch - if self.datasource_type is not None: - d[SystemVariableKey.DATASOURCE_TYPE] = self.datasource_type - if self.datasource_info is not None: - d[SystemVariableKey.DATASOURCE_INFO] = self.datasource_info - if self.invoke_from is not None: - d[SystemVariableKey.INVOKE_FROM] = self.invoke_from - if self.timestamp is not None: - d[SystemVariableKey.TIMESTAMP] = self.timestamp - return d - - def as_view(self) -> SystemVariableReadOnlyView: - return SystemVariableReadOnlyView(self) - - -class SystemVariableReadOnlyView: - """ - A read-only view of a SystemVariable that implements the ReadOnlySystemVariable protocol. - - This class wraps a SystemVariable instance and provides read-only access to all its fields. - It always reads the latest data from the wrapped instance and prevents any write operations. 
- """ - - def __init__(self, system_variable: SystemVariable) -> None: - """ - Initialize the read-only view with a SystemVariable instance. - - Args: - system_variable: The SystemVariable instance to wrap - """ - self._system_variable = system_variable - - @property - def user_id(self) -> str | None: - return self._system_variable.user_id - - @property - def app_id(self) -> str | None: - return self._system_variable.app_id - - @property - def workflow_id(self) -> str | None: - return self._system_variable.workflow_id - - @property - def workflow_execution_id(self) -> str | None: - return self._system_variable.workflow_execution_id - - @property - def query(self) -> str | None: - return self._system_variable.query - - @property - def conversation_id(self) -> str | None: - return self._system_variable.conversation_id - - @property - def dialogue_count(self) -> int | None: - return self._system_variable.dialogue_count - - @property - def document_id(self) -> str | None: - return self._system_variable.document_id - - @property - def original_document_id(self) -> str | None: - return self._system_variable.original_document_id - - @property - def dataset_id(self) -> str | None: - return self._system_variable.dataset_id - - @property - def batch(self) -> str | None: - return self._system_variable.batch - - @property - def datasource_type(self) -> str | None: - return self._system_variable.datasource_type - - @property - def invoke_from(self) -> str | None: - return self._system_variable.invoke_from - - @property - def files(self) -> Sequence[File]: - """ - Get a copy of the files from the wrapped SystemVariable. - - Returns: - A defensive copy of the files sequence to prevent modification - """ - return tuple(self._system_variable.files) # Convert to immutable tuple - - @property - def datasource_info(self) -> Mapping[str, Any] | None: - """ - Get a copy of the datasource info from the wrapped SystemVariable. 
- - Returns: - A view of the datasource info mapping to prevent modification - """ - if self._system_variable.datasource_info is None: - return None - return MappingProxyType(self._system_variable.datasource_info) - - def __repr__(self) -> str: - """Return a string representation of the read-only view.""" - return f"SystemVariableReadOnlyView(system_variable={self._system_variable!r})" diff --git a/api/dify_graph/utils/__init__.py b/api/dify_graph/utils/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/api/dify_graph/utils/condition/__init__.py b/api/dify_graph/utils/condition/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/api/dify_graph/utils/condition/entities.py b/api/dify_graph/utils/condition/entities.py deleted file mode 100644 index 77a214571a..0000000000 --- a/api/dify_graph/utils/condition/entities.py +++ /dev/null @@ -1,49 +0,0 @@ -from collections.abc import Sequence -from typing import Literal - -from pydantic import BaseModel, Field - -SupportedComparisonOperator = Literal[ - # for string or array - "contains", - "not contains", - "start with", - "end with", - "is", - "is not", - "empty", - "not empty", - "in", - "not in", - "all of", - # for number - "=", - "≠", - ">", - "<", - "≥", - "≤", - "null", - "not null", - # for file - "exists", - "not exists", -] - - -class SubCondition(BaseModel): - key: str - comparison_operator: SupportedComparisonOperator - value: str | Sequence[str] | None = None - - -class SubVariableCondition(BaseModel): - logical_operator: Literal["and", "or"] - conditions: list[SubCondition] = Field(default_factory=list) - - -class Condition(BaseModel): - variable_selector: list[str] - comparison_operator: SupportedComparisonOperator - value: str | Sequence[str] | bool | None = None - sub_variable_condition: SubVariableCondition | None = None diff --git a/api/dify_graph/utils/condition/processor.py b/api/dify_graph/utils/condition/processor.py deleted file mode 100644 
index dea72d96c2..0000000000 --- a/api/dify_graph/utils/condition/processor.py +++ /dev/null @@ -1,504 +0,0 @@ -import json -from collections.abc import Mapping, Sequence -from typing import Literal, NamedTuple - -from dify_graph.file import FileAttribute, file_manager -from dify_graph.runtime import VariablePool -from dify_graph.variables import ArrayFileSegment -from dify_graph.variables.segments import ArrayBooleanSegment, BooleanSegment - -from .entities import Condition, SubCondition, SupportedComparisonOperator - - -def _convert_to_bool(value: object) -> bool: - if isinstance(value, int): - return bool(value) - - if isinstance(value, str): - loaded = json.loads(value) - if isinstance(loaded, (int, bool)): - return bool(loaded) - - raise TypeError(f"unexpected value: type={type(value)}, value={value}") - - -class ConditionCheckResult(NamedTuple): - inputs: Sequence[Mapping[str, object]] - group_results: Sequence[bool] - final_result: bool - - -class ConditionProcessor: - def process_conditions( - self, - *, - variable_pool: VariablePool, - conditions: Sequence[Condition], - operator: Literal["and", "or"], - ) -> ConditionCheckResult: - input_conditions: list[Mapping[str, object]] = [] - group_results: list[bool] = [] - - for condition in conditions: - variable = variable_pool.get(condition.variable_selector) - if variable is None: - raise ValueError(f"Variable {condition.variable_selector} not found") - - if isinstance(variable, ArrayFileSegment) and condition.comparison_operator in { - "contains", - "not contains", - "all of", - }: - # check sub conditions - if not condition.sub_variable_condition: - raise ValueError("Sub variable is required") - result = _process_sub_conditions( - variable=variable, - sub_conditions=condition.sub_variable_condition.conditions, - operator=condition.sub_variable_condition.logical_operator, - ) - elif condition.comparison_operator in { - "exists", - "not exists", - }: - result = _evaluate_condition( - value=variable.value, - 
operator=condition.comparison_operator, - expected=None, - ) - else: - actual_value = variable.value if variable else None - expected_value: str | Sequence[str] | bool | list[bool] | None = condition.value - if isinstance(expected_value, str): - expected_value = variable_pool.convert_template(expected_value).text - # Here we need to explicit convet the input string to boolean. - if isinstance(variable, (BooleanSegment, ArrayBooleanSegment)) and expected_value is not None: - # The following two lines is for compatibility with existing workflows. - if isinstance(expected_value, list): - expected_value = [_convert_to_bool(i) for i in expected_value] - else: - expected_value = _convert_to_bool(expected_value) - input_conditions.append( - { - "actual_value": actual_value, - "expected_value": expected_value, - "comparison_operator": condition.comparison_operator, - } - ) - result = _evaluate_condition( - value=actual_value, - operator=condition.comparison_operator, - expected=expected_value, - ) - group_results.append(result) - # Implemented short-circuit evaluation for logical conditions - if (operator == "and" and not result) or (operator == "or" and result): - final_result = result - return ConditionCheckResult(input_conditions, group_results, final_result) - - final_result = all(group_results) if operator == "and" else any(group_results) - return ConditionCheckResult(input_conditions, group_results, final_result) - - -def _evaluate_condition( - *, - operator: SupportedComparisonOperator, - value: object, - expected: str | Sequence[str] | bool | Sequence[bool] | None, -) -> bool: - match operator: - case "contains": - return _assert_contains(value=value, expected=expected) - case "not contains": - return _assert_not_contains(value=value, expected=expected) - case "start with": - return _assert_start_with(value=value, expected=expected) - case "end with": - return _assert_end_with(value=value, expected=expected) - case "is": - return _assert_is(value=value, 
expected=expected) - case "is not": - return _assert_is_not(value=value, expected=expected) - case "empty": - return _assert_empty(value=value) - case "not empty": - return _assert_not_empty(value=value) - case "=": - return _assert_equal(value=value, expected=expected) - case "≠": - return _assert_not_equal(value=value, expected=expected) - case ">": - return _assert_greater_than(value=value, expected=expected) - case "<": - return _assert_less_than(value=value, expected=expected) - case "≥": - return _assert_greater_than_or_equal(value=value, expected=expected) - case "≤": - return _assert_less_than_or_equal(value=value, expected=expected) - case "null": - return _assert_null(value=value) - case "not null": - return _assert_not_null(value=value) - case "in": - return _assert_in(value=value, expected=expected) - case "not in": - return _assert_not_in(value=value, expected=expected) - case "all of" if isinstance(expected, list): - # Type narrowing: at this point expected is a list, could be list[str] or list[bool] - if all(isinstance(item, str) for item in expected): - # Create a new typed list to satisfy type checker - str_list: list[str] = [item for item in expected if isinstance(item, str)] - return _assert_all_of(value=value, expected=str_list) - elif all(isinstance(item, bool) for item in expected): - # Create a new typed list to satisfy type checker - bool_list: list[bool] = [item for item in expected if isinstance(item, bool)] - return _assert_all_of_bool(value=value, expected=bool_list) - else: - raise ValueError("all of operator expects homogeneous list of strings or booleans") - case "exists": - return _assert_exists(value=value) - case "not exists": - return _assert_not_exists(value=value) - case _: - raise ValueError(f"Unsupported operator: {operator}") - - -def _assert_contains(*, value: object, expected: object) -> bool: - if not value: - return False - - if not isinstance(value, (str, list)): - raise ValueError("Invalid actual value type: string or 
array") - - # Type checking ensures value is str or list at this point - if isinstance(value, str): - if not isinstance(expected, str): - expected = str(expected) - if expected not in value: - return False - else: # value is list - if expected not in value: - return False - return True - - -def _assert_not_contains(*, value: object, expected: object) -> bool: - if not value: - return True - - if not isinstance(value, (str, list)): - raise ValueError("Invalid actual value type: string or array") - - # Type checking ensures value is str or list at this point - if isinstance(value, str): - if not isinstance(expected, str): - expected = str(expected) - if expected in value: - return False - else: # value is list - if expected in value: - return False - return True - - -def _assert_start_with(*, value: object, expected: object) -> bool: - if not value: - return False - - if not isinstance(value, str): - raise ValueError("Invalid actual value type: string") - - if not isinstance(expected, str): - raise ValueError("Expected value must be a string for startswith") - if not value.startswith(expected): - return False - return True - - -def _assert_end_with(*, value: object, expected: object) -> bool: - if not value: - return False - - if not isinstance(value, str): - raise ValueError("Invalid actual value type: string") - - if not isinstance(expected, str): - raise ValueError("Expected value must be a string for endswith") - if not value.endswith(expected): - return False - return True - - -def _assert_is(*, value: object, expected: object) -> bool: - if value is None: - return False - - if not isinstance(value, (str, bool)): - raise ValueError("Invalid actual value type: string or boolean") - - if value != expected: - return False - return True - - -def _assert_is_not(*, value: object, expected: object) -> bool: - if value is None: - return False - - if not isinstance(value, (str, bool)): - raise ValueError("Invalid actual value type: string or boolean") - - if value == 
expected: - return False - return True - - -def _assert_empty(*, value: object) -> bool: - if not value: - return True - return False - - -def _assert_not_empty(*, value: object) -> bool: - if value: - return True - return False - - -def _normalize_numeric_values(value: int | float, expected: object) -> tuple[int | float, int | float]: - """ - Normalize value and expected to compatible numeric types for comparison. - - Args: - value: The actual numeric value (int or float) - expected: The expected value (int, float, or str) - - Returns: - A tuple of (normalized_value, normalized_expected) with compatible types - - Raises: - ValueError: If expected cannot be converted to a number - """ - if not isinstance(expected, (int, float, str)): - raise ValueError(f"Cannot convert {type(expected)} to number") - - # Convert expected to appropriate numeric type - if isinstance(expected, str): - # Try to convert to float first to handle decimal strings - try: - expected_float = float(expected) - except ValueError as e: - raise ValueError(f"Cannot convert '{expected}' to number") from e - - # If value is int and expected is a whole number, keep as int comparison - if isinstance(value, int) and expected_float.is_integer(): - return value, int(expected_float) - else: - # Otherwise convert value to float for comparison - return float(value) if isinstance(value, int) else value, expected_float - elif isinstance(expected, float): - # If expected is already float, convert int value to float - return float(value) if isinstance(value, int) else value, expected - else: - # expected is int - return value, expected - - -def _assert_equal(*, value: object, expected: object) -> bool: - if value is None: - return False - - if not isinstance(value, (int, float, bool)): - raise ValueError("Invalid actual value type: number or boolean") - - # Handle boolean comparison - if isinstance(value, bool): - if not isinstance(expected, (bool, int, str)): - raise ValueError(f"Cannot convert {type(expected)} 
to bool") - expected = bool(expected) - elif isinstance(value, int): - if not isinstance(expected, (int, float, str)): - raise ValueError(f"Cannot convert {type(expected)} to int") - expected = int(expected) - else: - if not isinstance(expected, (int, float, str)): - raise ValueError(f"Cannot convert {type(expected)} to float") - expected = float(expected) - - if value != expected: - return False - return True - - -def _assert_not_equal(*, value: object, expected: object) -> bool: - if value is None: - return False - - if not isinstance(value, (int, float, bool)): - raise ValueError("Invalid actual value type: number or boolean") - - # Handle boolean comparison - if isinstance(value, bool): - if not isinstance(expected, (bool, int, str)): - raise ValueError(f"Cannot convert {type(expected)} to bool") - expected = bool(expected) - elif isinstance(value, int): - if not isinstance(expected, (int, float, str)): - raise ValueError(f"Cannot convert {type(expected)} to int") - expected = int(expected) - else: - if not isinstance(expected, (int, float, str)): - raise ValueError(f"Cannot convert {type(expected)} to float") - expected = float(expected) - - if value == expected: - return False - return True - - -def _assert_greater_than(*, value: object, expected: object) -> bool: - if value is None: - return False - - if not isinstance(value, (int, float)): - raise ValueError("Invalid actual value type: number") - - value, expected = _normalize_numeric_values(value, expected) - return value > expected - - -def _assert_less_than(*, value: object, expected: object) -> bool: - if value is None: - return False - - if not isinstance(value, (int, float)): - raise ValueError("Invalid actual value type: number") - - value, expected = _normalize_numeric_values(value, expected) - return value < expected - - -def _assert_greater_than_or_equal(*, value: object, expected: object) -> bool: - if value is None: - return False - - if not isinstance(value, (int, float)): - raise 
ValueError("Invalid actual value type: number") - - value, expected = _normalize_numeric_values(value, expected) - return value >= expected - - -def _assert_less_than_or_equal(*, value: object, expected: object) -> bool: - if value is None: - return False - - if not isinstance(value, (int, float)): - raise ValueError("Invalid actual value type: number") - - value, expected = _normalize_numeric_values(value, expected) - return value <= expected - - -def _assert_null(*, value: object) -> bool: - if value is None: - return True - return False - - -def _assert_not_null(*, value: object) -> bool: - if value is not None: - return True - return False - - -def _assert_in(*, value: object, expected: object) -> bool: - if not value: - return False - - if not isinstance(expected, list): - raise ValueError("Invalid expected value type: array") - - if value not in expected: - return False - return True - - -def _assert_not_in(*, value: object, expected: object) -> bool: - if not value: - return True - - if not isinstance(expected, list): - raise ValueError("Invalid expected value type: array") - - if value in expected: - return False - return True - - -def _assert_all_of(*, value: object, expected: Sequence[str]) -> bool: - if not value: - return False - - # Ensure value is a container that supports 'in' operator - if not isinstance(value, (list, tuple, set, str)): - return False - - return all(item in value for item in expected) - - -def _assert_all_of_bool(*, value: object, expected: Sequence[bool]) -> bool: - if not value: - return False - - # Ensure value is a container that supports 'in' operator - if not isinstance(value, (list, tuple, set)): - return False - - return all(item in value for item in expected) - - -def _assert_exists(*, value: object) -> bool: - return value is not None - - -def _assert_not_exists(*, value: object) -> bool: - return value is None - - -def _process_sub_conditions( - variable: ArrayFileSegment, - sub_conditions: Sequence[SubCondition], - 
operator: Literal["and", "or"], -) -> bool: - files = variable.value - group_results: list[bool] = [] - for condition in sub_conditions: - key = FileAttribute(condition.key) - values = [file_manager.get_attr(file=file, attr=key) for file in files] - expected_value = condition.value - if key == FileAttribute.EXTENSION: - if not isinstance(expected_value, str): - raise TypeError("Expected value must be a string when key is FileAttribute.EXTENSION") - if expected_value and not expected_value.startswith("."): - expected_value = "." + expected_value - - normalized_values: list[object] = [] - for value in values: - if value and isinstance(value, str): - if not value.startswith("."): - value = "." + value - normalized_values.append(value) - values = normalized_values - sub_group_results: list[bool] = [ - _evaluate_condition( - value=value, - operator=condition.comparison_operator, - expected=expected_value, - ) - for value in values - ] - # Determine the result based on the presence of "not" in the comparison operator - result = all(sub_group_results) if "not" in condition.comparison_operator else any(sub_group_results) - group_results.append(result) - return all(group_results) if operator == "and" else any(group_results) diff --git a/api/dify_graph/variable_loader.py b/api/dify_graph/variable_loader.py deleted file mode 100644 index d263450334..0000000000 --- a/api/dify_graph/variable_loader.py +++ /dev/null @@ -1,83 +0,0 @@ -import abc -from collections.abc import Mapping, Sequence -from typing import Any, Protocol - -from dify_graph.runtime import VariablePool -from dify_graph.variables import VariableBase -from dify_graph.variables.consts import SELECTORS_LENGTH - - -class VariableLoader(Protocol): - """Interface for loading variables based on selectors. - - A `VariableLoader` is responsible for retrieving additional variables required during the execution - of a single node, which are not provided as user inputs. 
- - NOTE(QuantumGhost): Typically, all variables loaded by a `VariableLoader` should belong to the same - application and share the same `app_id`. However, this interface does not enforce that constraint, - and the `app_id` parameter is intentionally omitted from `load_variables` to achieve separation of - concern and allow for flexible implementations. - - Implementations of `VariableLoader` should almost always have an `app_id` parameter in - their constructor. - - TODO(QuantumGhost): this is a temporally workaround. If we can move the creation of node instance into - `WorkflowService.single_step_run`, we may get rid of this interface. - """ - - @abc.abstractmethod - def load_variables(self, selectors: list[list[str]]) -> list[VariableBase]: - """Load variables based on the provided selectors. If the selectors are empty, - this method should return an empty list. - - The order of the returned variables is not guaranteed. If the caller wants to ensure - a specific order, they should sort the returned list themselves. - - :param: selectors: a list of string list, each inner list should have at least two elements: - - the first element is the node ID, - - the second element is the variable name. - :return: a list of VariableBase objects that match the provided selectors. - """ - pass - - -class _DummyVariableLoader(VariableLoader): - """A dummy implementation of VariableLoader that does not load any variables. - Serves as a placeholder when no variable loading is needed. - """ - - def load_variables(self, selectors: list[list[str]]) -> list[VariableBase]: - return [] - - -DUMMY_VARIABLE_LOADER = _DummyVariableLoader() - - -def load_into_variable_pool( - variable_loader: VariableLoader, - variable_pool: VariablePool, - variable_mapping: Mapping[str, Sequence[str]], - user_inputs: Mapping[str, Any], -): - # Loading missing variable from draft var here, and set it into - # variable_pool. 
- variables_to_load: list[list[str]] = [] - for key, selector in variable_mapping.items(): - # NOTE(QuantumGhost): this logic needs to be in sync with - # `WorkflowEntry.mapping_user_inputs_to_variable_pool`. - node_variable_list = key.split(".") - if len(node_variable_list) < 2: - raise ValueError(f"Invalid variable key: {key}. It should have at least two elements.") - if key in user_inputs: - continue - node_variable_key = ".".join(node_variable_list[1:]) - if node_variable_key in user_inputs: - continue - if variable_pool.get(selector) is None: - variables_to_load.append(list(selector)) - loaded = variable_loader.load_variables(variables_to_load) - for var in loaded: - assert len(var.selector) >= SELECTORS_LENGTH, f"Invalid variable {var}" - # Add variable directly to the pool - # The variable pool expects 2-element selectors [node_id, variable_name] - variable_pool.add([var.selector[0], var.selector[1]], var) diff --git a/api/dify_graph/variables/__init__.py b/api/dify_graph/variables/__init__.py deleted file mode 100644 index be3fc8d97a..0000000000 --- a/api/dify_graph/variables/__init__.py +++ /dev/null @@ -1,70 +0,0 @@ -from .input_entities import VariableEntity, VariableEntityType -from .segment_group import SegmentGroup -from .segments import ( - ArrayAnySegment, - ArrayFileSegment, - ArrayNumberSegment, - ArrayObjectSegment, - ArraySegment, - ArrayStringSegment, - FileSegment, - FloatSegment, - IntegerSegment, - NoneSegment, - ObjectSegment, - Segment, - StringSegment, -) -from .types import SegmentType -from .variables import ( - ArrayAnyVariable, - ArrayFileVariable, - ArrayNumberVariable, - ArrayObjectVariable, - ArrayStringVariable, - ArrayVariable, - FileVariable, - FloatVariable, - IntegerVariable, - NoneVariable, - ObjectVariable, - SecretVariable, - StringVariable, - Variable, - VariableBase, -) - -__all__ = [ - "ArrayAnySegment", - "ArrayAnyVariable", - "ArrayFileSegment", - "ArrayFileVariable", - "ArrayNumberSegment", - "ArrayNumberVariable", - 
"ArrayObjectSegment", - "ArrayObjectVariable", - "ArraySegment", - "ArrayStringSegment", - "ArrayStringVariable", - "ArrayVariable", - "FileSegment", - "FileVariable", - "FloatSegment", - "FloatVariable", - "IntegerSegment", - "IntegerVariable", - "NoneSegment", - "NoneVariable", - "ObjectSegment", - "ObjectVariable", - "SecretVariable", - "Segment", - "SegmentGroup", - "SegmentType", - "StringSegment", - "StringVariable", - "Variable", - "VariableBase", - "VariableEntity", - "VariableEntityType", -] diff --git a/api/dify_graph/variables/consts.py b/api/dify_graph/variables/consts.py deleted file mode 100644 index 8f3f78f740..0000000000 --- a/api/dify_graph/variables/consts.py +++ /dev/null @@ -1,7 +0,0 @@ -# The minimal selector length for valid variables. -# -# The first element of the selector is the node id, and the second element is the variable name. -# -# If the selector length is more than 2, the remaining parts are the keys / indexes paths used -# to extract part of the variable value. 
-SELECTORS_LENGTH = 2 diff --git a/api/dify_graph/variables/exc.py b/api/dify_graph/variables/exc.py deleted file mode 100644 index 5cf67c3bac..0000000000 --- a/api/dify_graph/variables/exc.py +++ /dev/null @@ -1,2 +0,0 @@ -class VariableError(ValueError): - pass diff --git a/api/dify_graph/variables/input_entities.py b/api/dify_graph/variables/input_entities.py deleted file mode 100644 index e6a68ea359..0000000000 --- a/api/dify_graph/variables/input_entities.py +++ /dev/null @@ -1,62 +0,0 @@ -from collections.abc import Sequence -from enum import StrEnum -from typing import Any - -from jsonschema import Draft7Validator, SchemaError -from pydantic import BaseModel, Field, field_validator - -from dify_graph.file import FileTransferMethod, FileType - - -class VariableEntityType(StrEnum): - TEXT_INPUT = "text-input" - SELECT = "select" - PARAGRAPH = "paragraph" - NUMBER = "number" - EXTERNAL_DATA_TOOL = "external_data_tool" - FILE = "file" - FILE_LIST = "file-list" - CHECKBOX = "checkbox" - JSON_OBJECT = "json_object" - - -class VariableEntity(BaseModel): - """ - Shared variable entity used by workflow runtime and app configuration. - """ - - # `variable` records the name of the variable in user inputs. 
- variable: str - label: str - description: str = "" - type: VariableEntityType - required: bool = False - hide: bool = False - default: Any = None - max_length: int | None = None - options: Sequence[str] = Field(default_factory=list) - allowed_file_types: Sequence[FileType] | None = Field(default_factory=list) - allowed_file_extensions: Sequence[str] | None = Field(default_factory=list) - allowed_file_upload_methods: Sequence[FileTransferMethod] | None = Field(default_factory=list) - json_schema: dict[str, Any] | None = Field(default=None) - - @field_validator("description", mode="before") - @classmethod - def convert_none_description(cls, value: Any) -> str: - return value or "" - - @field_validator("options", mode="before") - @classmethod - def convert_none_options(cls, value: Any) -> Sequence[str]: - return value or [] - - @field_validator("json_schema") - @classmethod - def validate_json_schema(cls, schema: dict[str, Any] | None) -> dict[str, Any] | None: - if schema is None: - return None - try: - Draft7Validator.check_schema(schema) - except SchemaError as error: - raise ValueError(f"Invalid JSON schema: {error.message}") - return schema diff --git a/api/dify_graph/variables/segment_group.py b/api/dify_graph/variables/segment_group.py deleted file mode 100644 index b363255b2c..0000000000 --- a/api/dify_graph/variables/segment_group.py +++ /dev/null @@ -1,22 +0,0 @@ -from .segments import Segment -from .types import SegmentType - - -class SegmentGroup(Segment): - value_type: SegmentType = SegmentType.GROUP - value: list[Segment] - - @property - def text(self): - return "".join([segment.text for segment in self.value]) - - @property - def log(self): - return "".join([segment.log for segment in self.value]) - - @property - def markdown(self): - return "".join([segment.markdown for segment in self.value]) - - def to_object(self): - return [segment.to_object() for segment in self.value] diff --git a/api/dify_graph/variables/segments.py 
b/api/dify_graph/variables/segments.py deleted file mode 100644 index bdb213ed48..0000000000 --- a/api/dify_graph/variables/segments.py +++ /dev/null @@ -1,253 +0,0 @@ -import json -import sys -from collections.abc import Mapping, Sequence -from typing import Annotated, Any, TypeAlias - -from pydantic import BaseModel, ConfigDict, Discriminator, Tag, field_validator - -from dify_graph.file import File - -from .types import SegmentType - - -class Segment(BaseModel): - """Segment is runtime type used during the execution of workflow. - - Note: this class is abstract, you should use subclasses of this class instead. - """ - - model_config = ConfigDict(frozen=True) - - value_type: SegmentType - value: Any - - @field_validator("value_type") - @classmethod - def validate_value_type(cls, value): - """ - This validator checks if the provided value is equal to the default value of the 'value_type' field. - If the value is different, a ValueError is raised. - """ - if value != cls.model_fields["value_type"].default: - raise ValueError("Cannot modify 'value_type'") - return value - - @property - def text(self) -> str: - return str(self.value) - - @property - def log(self) -> str: - return str(self.value) - - @property - def markdown(self) -> str: - return str(self.value) - - @property - def size(self) -> int: - """ - Return the size of the value in bytes. - """ - return sys.getsizeof(self.value) - - def to_object(self): - return self.value - - -class NoneSegment(Segment): - value_type: SegmentType = SegmentType.NONE - value: None = None - - @property - def text(self) -> str: - return "" - - @property - def log(self) -> str: - return "" - - @property - def markdown(self) -> str: - return "" - - -class StringSegment(Segment): - value_type: SegmentType = SegmentType.STRING - value: str - - -class FloatSegment(Segment): - value_type: SegmentType = SegmentType.FLOAT - value: float - # NOTE(QuantumGhost): seems that the equality for FloatSegment with `NaN` value has some problems. 
- # The following tests cannot pass. - # - # def test_float_segment_and_nan(): - # nan = float("nan") - # assert nan != nan - # - # f1 = FloatSegment(value=float("nan")) - # f2 = FloatSegment(value=float("nan")) - # assert f1 != f2 - # - # f3 = FloatSegment(value=nan) - # f4 = FloatSegment(value=nan) - # assert f3 != f4 - - -class IntegerSegment(Segment): - value_type: SegmentType = SegmentType.INTEGER - value: int - - -class ObjectSegment(Segment): - value_type: SegmentType = SegmentType.OBJECT - value: Mapping[str, Any] - - @property - def text(self) -> str: - return json.dumps(self.model_dump()["value"], ensure_ascii=False) - - @property - def log(self) -> str: - return json.dumps(self.model_dump()["value"], ensure_ascii=False, indent=2) - - @property - def markdown(self) -> str: - return json.dumps(self.model_dump()["value"], ensure_ascii=False, indent=2) - - -class ArraySegment(Segment): - @property - def text(self) -> str: - # Return empty string for empty arrays instead of "[]" - if not self.value: - return "" - return super().text - - @property - def markdown(self) -> str: - items = [] - for item in self.value: - items.append(f"- {item}") - return "\n".join(items) - - -class FileSegment(Segment): - value_type: SegmentType = SegmentType.FILE - value: File - - @property - def markdown(self) -> str: - return self.value.markdown - - @property - def log(self) -> str: - return "" - - @property - def text(self) -> str: - return "" - - -class BooleanSegment(Segment): - value_type: SegmentType = SegmentType.BOOLEAN - value: bool - - -class ArrayAnySegment(ArraySegment): - value_type: SegmentType = SegmentType.ARRAY_ANY - value: Sequence[Any] - - -class ArrayStringSegment(ArraySegment): - value_type: SegmentType = SegmentType.ARRAY_STRING - value: Sequence[str] - - @property - def text(self) -> str: - # Return empty string for empty arrays instead of "[]" - if not self.value: - return "" - return json.dumps(self.value, ensure_ascii=False) - - -class 
ArrayNumberSegment(ArraySegment): - value_type: SegmentType = SegmentType.ARRAY_NUMBER - value: Sequence[float | int] - - -class ArrayObjectSegment(ArraySegment): - value_type: SegmentType = SegmentType.ARRAY_OBJECT - value: Sequence[Mapping[str, Any]] - - -class ArrayFileSegment(ArraySegment): - value_type: SegmentType = SegmentType.ARRAY_FILE - value: Sequence[File] - - @property - def markdown(self) -> str: - items = [] - for item in self.value: - items.append(item.markdown) - return "\n".join(items) - - @property - def log(self) -> str: - return "" - - @property - def text(self) -> str: - return "" - - -class ArrayBooleanSegment(ArraySegment): - value_type: SegmentType = SegmentType.ARRAY_BOOLEAN - value: Sequence[bool] - - -def get_segment_discriminator(v: Any) -> SegmentType | None: - if isinstance(v, Segment): - return v.value_type - elif isinstance(v, dict): - value_type = v.get("value_type") - if value_type is None: - return None - try: - seg_type = SegmentType(value_type) - except ValueError: - return None - return seg_type - else: - # return None if the discriminator value isn't found - return None - - -# The `SegmentUnion`` type is used to enable serialization and deserialization with Pydantic. -# Use `Segment` for type hinting when serialization is not required. -# -# Note: -# - All variants in `SegmentUnion` must inherit from the `Segment` class. -# - The union must include all non-abstract subclasses of `Segment`, except: -# - `SegmentGroup`, which is not added to the variable pool. -# - `VariableBase` and its subclasses, which are handled by `Variable`. 
-SegmentUnion: TypeAlias = Annotated[ - ( - Annotated[NoneSegment, Tag(SegmentType.NONE)] - | Annotated[StringSegment, Tag(SegmentType.STRING)] - | Annotated[FloatSegment, Tag(SegmentType.FLOAT)] - | Annotated[IntegerSegment, Tag(SegmentType.INTEGER)] - | Annotated[ObjectSegment, Tag(SegmentType.OBJECT)] - | Annotated[FileSegment, Tag(SegmentType.FILE)] - | Annotated[BooleanSegment, Tag(SegmentType.BOOLEAN)] - | Annotated[ArrayAnySegment, Tag(SegmentType.ARRAY_ANY)] - | Annotated[ArrayStringSegment, Tag(SegmentType.ARRAY_STRING)] - | Annotated[ArrayNumberSegment, Tag(SegmentType.ARRAY_NUMBER)] - | Annotated[ArrayObjectSegment, Tag(SegmentType.ARRAY_OBJECT)] - | Annotated[ArrayFileSegment, Tag(SegmentType.ARRAY_FILE)] - | Annotated[ArrayBooleanSegment, Tag(SegmentType.ARRAY_BOOLEAN)] - ), - Discriminator(get_segment_discriminator), -] diff --git a/api/dify_graph/variables/types.py b/api/dify_graph/variables/types.py deleted file mode 100644 index 53bf495a27..0000000000 --- a/api/dify_graph/variables/types.py +++ /dev/null @@ -1,273 +0,0 @@ -from __future__ import annotations - -from collections.abc import Mapping -from enum import StrEnum -from typing import TYPE_CHECKING, Any - -from dify_graph.file.models import File - -if TYPE_CHECKING: - from dify_graph.variables.segments import Segment - - -class ArrayValidation(StrEnum): - """Strategy for validating array elements. - - Note: - The `NONE` and `FIRST` strategies are primarily for compatibility purposes. - Avoid using them in new code whenever possible. - """ - - # Skip element validation (only check array container) - NONE = "none" - - # Validate the first element (if array is non-empty) - FIRST = "first" - - # Validate all elements in the array. 
- ALL = "all" - - -class SegmentType(StrEnum): - NUMBER = "number" - INTEGER = "integer" - FLOAT = "float" - STRING = "string" - OBJECT = "object" - SECRET = "secret" - - FILE = "file" - BOOLEAN = "boolean" - - ARRAY_ANY = "array[any]" - ARRAY_STRING = "array[string]" - ARRAY_NUMBER = "array[number]" - ARRAY_OBJECT = "array[object]" - ARRAY_FILE = "array[file]" - ARRAY_BOOLEAN = "array[boolean]" - - NONE = "none" - - GROUP = "group" - - def is_array_type(self) -> bool: - return self in _ARRAY_TYPES - - @classmethod - def infer_segment_type(cls, value: Any) -> SegmentType | None: - """ - Attempt to infer the `SegmentType` based on the Python type of the `value` parameter. - - Returns `None` if no appropriate `SegmentType` can be determined for the given `value`. - For example, this may occur if the input is a generic Python object of type `object`. - """ - - if isinstance(value, list): - elem_types: set[SegmentType] = set() - for i in value: - segment_type = cls.infer_segment_type(i) - if segment_type is None: - return None - - elem_types.add(segment_type) - - if len(elem_types) != 1: - if elem_types.issubset(_NUMERICAL_TYPES): - return SegmentType.ARRAY_NUMBER - return SegmentType.ARRAY_ANY - elif all(i.is_array_type() for i in elem_types): - return SegmentType.ARRAY_ANY - match elem_types.pop(): - case SegmentType.STRING: - return SegmentType.ARRAY_STRING - case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT: - return SegmentType.ARRAY_NUMBER - case SegmentType.OBJECT: - return SegmentType.ARRAY_OBJECT - case SegmentType.FILE: - return SegmentType.ARRAY_FILE - case SegmentType.NONE: - return SegmentType.ARRAY_ANY - case SegmentType.BOOLEAN: - return SegmentType.ARRAY_BOOLEAN - case _: - # This should be unreachable. - raise ValueError(f"not supported value {value}") - if value is None: - return SegmentType.NONE - # Important: The check for `bool` must precede the check for `int`, - # as `bool` is a subclass of `int` in Python's type hierarchy. 
- elif isinstance(value, bool): - return SegmentType.BOOLEAN - elif isinstance(value, int): - return SegmentType.INTEGER - elif isinstance(value, float): - return SegmentType.FLOAT - elif isinstance(value, str): - return SegmentType.STRING - elif isinstance(value, dict): - return SegmentType.OBJECT - elif isinstance(value, File): - return SegmentType.FILE - else: - return None - - def _validate_array(self, value: Any, array_validation: ArrayValidation) -> bool: - if not isinstance(value, list): - return False - # Skip element validation if array is empty - if len(value) == 0: - return True - if self == SegmentType.ARRAY_ANY: - return True - element_type = _ARRAY_ELEMENT_TYPES_MAPPING[self] - - if array_validation == ArrayValidation.NONE: - return True - elif array_validation == ArrayValidation.FIRST: - return element_type.is_valid(value[0]) - else: - return all(element_type.is_valid(i, array_validation=ArrayValidation.NONE) for i in value) - - def is_valid(self, value: Any, array_validation: ArrayValidation = ArrayValidation.ALL) -> bool: - """ - Check if a value matches the segment type. - Users of `SegmentType` should call this method, instead of using - `isinstance` manually. - - Args: - value: The value to validate - array_validation: Validation strategy for array types (ignored for non-array types) - - Returns: - True if the value matches the type under the given validation strategy - """ - if self.is_array_type(): - return self._validate_array(value, array_validation) - # Important: The check for `bool` must precede the check for `int`, - # as `bool` is a subclass of `int` in Python's type hierarchy. 
- elif self == SegmentType.BOOLEAN: - return isinstance(value, bool) - elif self in [SegmentType.INTEGER, SegmentType.FLOAT, SegmentType.NUMBER]: - return isinstance(value, (int, float)) - elif self == SegmentType.STRING: - return isinstance(value, str) - elif self == SegmentType.OBJECT: - return isinstance(value, dict) - elif self == SegmentType.SECRET: - return isinstance(value, str) - elif self == SegmentType.FILE: - return isinstance(value, File) - elif self == SegmentType.NONE: - return value is None - elif self == SegmentType.GROUP: - from .segment_group import SegmentGroup - from .segments import Segment - - if isinstance(value, SegmentGroup): - return all(isinstance(item, Segment) for item in value.value) - - if isinstance(value, list): - return all(isinstance(item, Segment) for item in value) - - return False - else: - raise AssertionError("this statement should be unreachable.") - - @staticmethod - def cast_value(value: Any, type_: SegmentType): - # Cast Python's `bool` type to `int` when the runtime type requires - # an integer or number. - # - # This ensures compatibility with existing workflows that may use `bool` as - # `int`, since in Python's type system, `bool` is a subtype of `int`. - # - # This function exists solely to maintain compatibility with existing workflows. - # It should not be used to compromise the integrity of the runtime type system. - # No additional casting rules should be introduced to this function. - - if type_ in ( - SegmentType.INTEGER, - SegmentType.NUMBER, - ) and isinstance(value, bool): - return int(value) - if type_ == SegmentType.ARRAY_NUMBER and all(isinstance(i, bool) for i in value): - return [int(i) for i in value] - return value - - def exposed_type(self) -> SegmentType: - """Returns the type exposed to the frontend. - - The frontend treats `INTEGER` and `FLOAT` as `NUMBER`, so these are returned as `NUMBER` here. 
- """ - if self in (SegmentType.INTEGER, SegmentType.FLOAT): - return SegmentType.NUMBER - return self - - def element_type(self) -> SegmentType | None: - """Return the element type of the current segment type, or `None` if the element type is undefined. - - Raises: - ValueError: If the current segment type is not an array type. - - Note: - For certain array types, such as `SegmentType.ARRAY_ANY`, their element types are not defined - by the runtime system. In such cases, this method will return `None`. - """ - if not self.is_array_type(): - raise ValueError(f"element_type is only supported by array type, got {self}") - return _ARRAY_ELEMENT_TYPES_MAPPING.get(self) - - @staticmethod - def get_zero_value(t: SegmentType) -> Segment: - # Lazy import to avoid circular dependency - from factories import variable_factory - - match t: - case ( - SegmentType.ARRAY_OBJECT - | SegmentType.ARRAY_ANY - | SegmentType.ARRAY_STRING - | SegmentType.ARRAY_NUMBER - | SegmentType.ARRAY_BOOLEAN - ): - return variable_factory.build_segment_with_type(t, []) - case SegmentType.OBJECT: - return variable_factory.build_segment({}) - case SegmentType.STRING: - return variable_factory.build_segment("") - case SegmentType.INTEGER: - return variable_factory.build_segment(0) - case SegmentType.FLOAT: - return variable_factory.build_segment(0.0) - case SegmentType.NUMBER: - return variable_factory.build_segment(0) - case SegmentType.BOOLEAN: - return variable_factory.build_segment(False) - case _: - raise ValueError(f"unsupported variable type: {t}") - - -_ARRAY_ELEMENT_TYPES_MAPPING: Mapping[SegmentType, SegmentType] = { - # ARRAY_ANY does not have corresponding element type. 
- SegmentType.ARRAY_STRING: SegmentType.STRING, - SegmentType.ARRAY_NUMBER: SegmentType.NUMBER, - SegmentType.ARRAY_OBJECT: SegmentType.OBJECT, - SegmentType.ARRAY_FILE: SegmentType.FILE, - SegmentType.ARRAY_BOOLEAN: SegmentType.BOOLEAN, -} - -_ARRAY_TYPES = frozenset( - list(_ARRAY_ELEMENT_TYPES_MAPPING.keys()) - + [ - SegmentType.ARRAY_ANY, - ] -) - -_NUMERICAL_TYPES = frozenset( - [ - SegmentType.NUMBER, - SegmentType.INTEGER, - SegmentType.FLOAT, - ] -) diff --git a/api/dify_graph/variables/utils.py b/api/dify_graph/variables/utils.py deleted file mode 100644 index 8e738f8fd5..0000000000 --- a/api/dify_graph/variables/utils.py +++ /dev/null @@ -1,33 +0,0 @@ -from collections.abc import Iterable, Sequence -from typing import Any - -import orjson - -from .segment_group import SegmentGroup -from .segments import ArrayFileSegment, FileSegment, Segment - - -def to_selector(node_id: str, name: str, paths: Iterable[str] = ()) -> Sequence[str]: - selectors = [node_id, name] - if paths: - selectors.extend(paths) - return selectors - - -def segment_orjson_default(o: Any): - """Default function for orjson serialization of Segment types""" - if isinstance(o, ArrayFileSegment): - return [v.model_dump() for v in o.value] - elif isinstance(o, FileSegment): - return o.value.model_dump() - elif isinstance(o, SegmentGroup): - return [segment_orjson_default(seg) for seg in o.value] - elif isinstance(o, Segment): - return o.value - raise TypeError(f"Object of type {type(o).__name__} is not JSON serializable") - - -def dumps_with_segments(obj: Any, ensure_ascii: bool = False) -> str: - """JSON dumps with segment support using orjson""" - option = orjson.OPT_NON_STR_KEYS - return orjson.dumps(obj, default=segment_orjson_default, option=option).decode("utf-8") diff --git a/api/dify_graph/variables/variables.py b/api/dify_graph/variables/variables.py deleted file mode 100644 index af866283da..0000000000 --- a/api/dify_graph/variables/variables.py +++ /dev/null @@ -1,172 +0,0 @@ -from 
collections.abc import Sequence -from typing import Annotated, Any, TypeAlias -from uuid import uuid4 - -from pydantic import BaseModel, Discriminator, Field, Tag - -from .segments import ( - ArrayAnySegment, - ArrayBooleanSegment, - ArrayFileSegment, - ArrayNumberSegment, - ArrayObjectSegment, - ArraySegment, - ArrayStringSegment, - BooleanSegment, - FileSegment, - FloatSegment, - IntegerSegment, - NoneSegment, - ObjectSegment, - Segment, - StringSegment, - get_segment_discriminator, -) -from .types import SegmentType - - -def _obfuscated_token(token: str) -> str: - if not token: - return token - if len(token) <= 8: - return "*" * 20 - return token[:6] + "*" * 12 + token[-2:] - - -class VariableBase(Segment): - """ - A variable is a segment that has a name. - - It is mainly used to store segments and their selector in VariablePool. - - Note: this class is abstract, you should use subclasses of this class instead. - """ - - id: str = Field( - default_factory=lambda: str(uuid4()), - description="Unique identity for variable.", - ) - name: str - description: str = Field(default="", description="Description of the variable.") - selector: Sequence[str] = Field(default_factory=list) - - -class StringVariable(StringSegment, VariableBase): - pass - - -class FloatVariable(FloatSegment, VariableBase): - pass - - -class IntegerVariable(IntegerSegment, VariableBase): - pass - - -class ObjectVariable(ObjectSegment, VariableBase): - pass - - -class ArrayVariable(ArraySegment, VariableBase): - pass - - -class ArrayAnyVariable(ArrayAnySegment, ArrayVariable): - pass - - -class ArrayStringVariable(ArrayStringSegment, ArrayVariable): - pass - - -class ArrayNumberVariable(ArrayNumberSegment, ArrayVariable): - pass - - -class ArrayObjectVariable(ArrayObjectSegment, ArrayVariable): - pass - - -class SecretVariable(StringVariable): - value_type: SegmentType = SegmentType.SECRET - - @property - def log(self) -> str: - return _obfuscated_token(self.value) - - -class 
NoneVariable(NoneSegment, VariableBase): - value_type: SegmentType = SegmentType.NONE - value: None = None - - -class FileVariable(FileSegment, VariableBase): - pass - - -class BooleanVariable(BooleanSegment, VariableBase): - pass - - -class ArrayFileVariable(ArrayFileSegment, ArrayVariable): - pass - - -class ArrayBooleanVariable(ArrayBooleanSegment, ArrayVariable): - pass - - -class RAGPipelineVariable(BaseModel): - belong_to_node_id: str = Field(description="belong to which node id, shared means public") - type: str = Field(description="variable type, text-input, paragraph, select, number, file, file-list") - label: str = Field(description="label") - description: str | None = Field(description="description", default="") - variable: str = Field(description="variable key", default="") - max_length: int | None = Field( - description="max length, applicable to text-input, paragraph, and file-list", default=0 - ) - default_value: Any = Field(description="default value", default="") - placeholder: str | None = Field(description="placeholder", default="") - unit: str | None = Field(description="unit, applicable to Number", default="") - tooltips: str | None = Field(description="helpful text", default="") - allowed_file_types: list[str] | None = Field( - description="image, document, audio, video, custom.", default_factory=list - ) - allowed_file_extensions: list[str] | None = Field(description="e.g. ['.jpg', '.mp3']", default_factory=list) - allowed_file_upload_methods: list[str] | None = Field( - description="remote_url, local_file, tool_file.", default_factory=list - ) - required: bool = Field(description="optional, default false", default=False) - options: list[str] | None = Field(default_factory=list) - - -class RAGPipelineVariableInput(BaseModel): - variable: RAGPipelineVariable - value: Any - - -# The `Variable` type is used to enable serialization and deserialization with Pydantic. -# Use `VariableBase` for type hinting when serialization is not required. 
-# -# Note: -# - All variants in `Variable` must inherit from the `VariableBase` class. -# - The union must include all non-abstract subclasses of `VariableBase`. -Variable: TypeAlias = Annotated[ - ( - Annotated[NoneVariable, Tag(SegmentType.NONE)] - | Annotated[StringVariable, Tag(SegmentType.STRING)] - | Annotated[FloatVariable, Tag(SegmentType.FLOAT)] - | Annotated[IntegerVariable, Tag(SegmentType.INTEGER)] - | Annotated[ObjectVariable, Tag(SegmentType.OBJECT)] - | Annotated[FileVariable, Tag(SegmentType.FILE)] - | Annotated[BooleanVariable, Tag(SegmentType.BOOLEAN)] - | Annotated[ArrayAnyVariable, Tag(SegmentType.ARRAY_ANY)] - | Annotated[ArrayStringVariable, Tag(SegmentType.ARRAY_STRING)] - | Annotated[ArrayNumberVariable, Tag(SegmentType.ARRAY_NUMBER)] - | Annotated[ArrayObjectVariable, Tag(SegmentType.ARRAY_OBJECT)] - | Annotated[ArrayFileVariable, Tag(SegmentType.ARRAY_FILE)] - | Annotated[ArrayBooleanVariable, Tag(SegmentType.ARRAY_BOOLEAN)] - | Annotated[SecretVariable, Tag(SegmentType.SECRET)] - ), - Discriminator(get_segment_discriminator), -] diff --git a/api/dify_graph/workflow_type_encoder.py b/api/dify_graph/workflow_type_encoder.py deleted file mode 100644 index 3dd846b3cb..0000000000 --- a/api/dify_graph/workflow_type_encoder.py +++ /dev/null @@ -1,49 +0,0 @@ -from collections.abc import Mapping -from decimal import Decimal -from typing import Any, overload - -from pydantic import BaseModel - -from dify_graph.file.models import File -from dify_graph.variables import Segment - - -class WorkflowRuntimeTypeConverter: - @overload - def to_json_encodable(self, value: Mapping[str, Any]) -> Mapping[str, Any]: ... - @overload - def to_json_encodable(self, value: None) -> None: ... 
- - def to_json_encodable(self, value: Mapping[str, Any] | None) -> Mapping[str, Any] | None: - """Convert runtime values to JSON-serializable structures.""" - - result = self.value_to_json_encodable_recursive(value) - if isinstance(result, Mapping) or result is None: - return result - return {} - - def value_to_json_encodable_recursive(self, value: Any): - if value is None: - return value - if isinstance(value, (bool, int, str, float)): - return value - if isinstance(value, Decimal): - # Convert Decimal to float for JSON serialization - return float(value) - if isinstance(value, Segment): - return self.value_to_json_encodable_recursive(value.value) - if isinstance(value, File): - return value.to_dict() - if isinstance(value, BaseModel): - return value.model_dump(mode="json") - if isinstance(value, dict): - res = {} - for k, v in value.items(): - res[k] = self.value_to_json_encodable_recursive(v) - return res - if isinstance(value, list): - res_list = [] - for item in value: - res_list.append(self.value_to_json_encodable_recursive(item)) - return res_list - return value diff --git a/api/dify_graph/__init__.py b/api/enterprise/__init__.py similarity index 100% rename from api/dify_graph/__init__.py rename to api/enterprise/__init__.py diff --git a/api/enterprise/telemetry/DATA_DICTIONARY.md b/api/enterprise/telemetry/DATA_DICTIONARY.md new file mode 100644 index 0000000000..60d482cd1c --- /dev/null +++ b/api/enterprise/telemetry/DATA_DICTIONARY.md @@ -0,0 +1,525 @@ +# Dify Enterprise Telemetry Data Dictionary + +Quick reference for all telemetry signals emitted by Dify Enterprise. For configuration and architecture details, see [README.md](./README.md). + +## Resource Attributes + +Attached to every signal (Span, Metric, Log). 
+ +| Attribute | Type | Example | +|-----------|------|---------| +| `service.name` | string | `dify` | +| `host.name` | string | `dify-api-7f8b` | + +## Traces (Spans) + +### `dify.workflow.run` + +| Attribute | Type | Description | +|-----------|------|-------------| +| `dify.trace_id` | string | Business trace ID (Workflow Run ID) | +| `dify.tenant_id` | string | Tenant identifier | +| `dify.app_id` | string | Application identifier | +| `dify.workflow.id` | string | Workflow definition ID | +| `dify.workflow.run_id` | string | Unique ID for this run | +| `dify.workflow.status` | string | `succeeded`, `failed`, `stopped`, etc. | +| `dify.workflow.error` | string | Error message if failed | +| `dify.workflow.elapsed_time` | float | Total execution time (seconds) | +| `dify.invoke_from` | string | `api`, `webapp`, `debug` | +| `dify.conversation.id` | string | Conversation ID (optional) | +| `dify.message.id` | string | Message ID (optional) | +| `dify.invoked_by` | string | User ID who triggered the run | +| `gen_ai.usage.total_tokens` | int | Total tokens across all nodes (optional) | +| `gen_ai.user.id` | string | End-user identifier (optional) | +| `dify.parent.trace_id` | string | Parent workflow trace ID (optional) | +| `dify.parent.workflow.run_id` | string | Parent workflow run ID (optional) | +| `dify.parent.node.execution_id` | string | Parent node execution ID (optional) | +| `dify.parent.app.id` | string | Parent app ID (optional) | + +### `dify.node.execution` + +| Attribute | Type | Description | +|-----------|------|-------------| +| `dify.trace_id` | string | Business trace ID | +| `dify.tenant_id` | string | Tenant identifier | +| `dify.app_id` | string | Application identifier | +| `dify.workflow.id` | string | Workflow definition ID | +| `dify.workflow.run_id` | string | Workflow Run ID | +| `dify.message.id` | string | Message ID (optional) | +| `dify.conversation.id` | string | Conversation ID (optional) | +| `dify.node.execution_id` | string 
| Unique node execution ID | +| `dify.node.id` | string | Node ID in workflow graph | +| `dify.node.type` | string | Node type (see appendix) | +| `dify.node.title` | string | Display title | +| `dify.node.status` | string | `succeeded`, `failed` | +| `dify.node.error` | string | Error message if failed | +| `dify.node.elapsed_time` | float | Execution time (seconds) | +| `dify.node.index` | int | Execution order index | +| `dify.node.predecessor_node_id` | string | Triggering node ID | +| `dify.node.iteration_id` | string | Iteration ID (optional) | +| `dify.node.loop_id` | string | Loop ID (optional) | +| `dify.node.parallel_id` | string | Parallel branch ID (optional) | +| `dify.node.invoked_by` | string | User ID who triggered execution | +| `gen_ai.usage.input_tokens` | int | Prompt tokens (LLM nodes only) | +| `gen_ai.usage.output_tokens` | int | Completion tokens (LLM nodes only) | +| `gen_ai.usage.total_tokens` | int | Total tokens (LLM nodes only) | +| `gen_ai.request.model` | string | LLM model name (LLM nodes only) | +| `gen_ai.provider.name` | string | LLM provider name (LLM nodes only) | +| `gen_ai.user.id` | string | End-user identifier (optional) | + +### `dify.node.execution.draft` + +Same attributes as `dify.node.execution`. Emitted during Preview/Debug runs. + +## Counters + +All counters are cumulative and emitted at 100% accuracy. + +### Token Counters + +| Metric | Unit | Description | +|--------|------|-------------| +| `dify.tokens.total` | `{token}` | Total tokens consumed | +| `dify.tokens.input` | `{token}` | Input (prompt) tokens | +| `dify.tokens.output` | `{token}` | Output (completion) tokens | + +**Labels:** + +- `tenant_id`, `app_id`, `operation_type`, `model_provider`, `model_name`, `node_type` (if node_execution) + +⚠️ **Warning:** `dify.tokens.total` at workflow level includes all node tokens. Filter by `operation_type` to avoid double-counting. 
+ +#### Token Hierarchy & Query Patterns + +Token metrics are emitted at multiple layers. Understanding the hierarchy prevents double-counting: + +``` +App-level total +├── workflow ← sum of all node_execution tokens (DO NOT add both) +│ └── node_execution ← per-node breakdown +├── message ← independent (non-workflow chat apps only) +├── rule_generate ← independent helper LLM call +├── code_generate ← independent helper LLM call +├── structured_output ← independent helper LLM call +└── instruction_modify← independent helper LLM call +``` + +**Key rule:** `workflow` tokens already include all `node_execution` tokens. Never sum both. + +**Available labels on token metrics:** `tenant_id`, `app_id`, `operation_type`, `model_provider`, `model_name`, `node_type`. +App name is only available on span attributes (`dify.app.name`), not metric labels — use `app_id` for metric queries. + +**Common queries** (PromQL): + +```promql +# ── Totals ────────────────────────────────────────────────── +# App-level total (exclude node_execution to avoid double-counting) +sum by (app_id) (dify_tokens_total{operation_type!="node_execution"}) + +# Single app total +sum (dify_tokens_total{app_id="", operation_type!="node_execution"}) + +# Per-tenant totals +sum by (tenant_id) (dify_tokens_total{operation_type!="node_execution"}) + +# ── Drill-down ────────────────────────────────────────────── +# Workflow-level tokens for an app +sum (dify_tokens_total{app_id="", operation_type="workflow"}) + +# Node-level breakdown within an app +sum by (node_type) (dify_tokens_total{app_id="", operation_type="node_execution"}) + +# Model breakdown for an app +sum by (model_provider, model_name) (dify_tokens_total{app_id=""}) + +# Input vs output per model +sum by (model_name) (dify_tokens_input_total{app_id=""}) +sum by (model_name) (dify_tokens_output_total{app_id=""}) + +# ── Rates ─────────────────────────────────────────────────── +# Token consumption rate (per hour) 
+sum(rate(dify_tokens_total{operation_type!="node_execution"}[1h])) + +# Per-app consumption rate +sum by (app_id) (rate(dify_tokens_total{operation_type!="node_execution"}[1h])) +``` + +**Finding `app_id` from app name** (trace query — Tempo / Jaeger): + +``` +{ resource.dify.app.name = "My Chatbot" } | select(resource.dify.app.id) +``` + +### Request Counters + +| Metric | Unit | Description | +|--------|------|-------------| +| `dify.requests.total` | `{request}` | Total operations count | + +**Labels by type:** + +| `type` | Additional Labels | +|--------|-------------------| +| `workflow` | `tenant_id`, `app_id`, `status`, `invoke_from` | +| `node` | `tenant_id`, `app_id`, `node_type`, `model_provider`, `model_name`, `status` | +| `draft_node` | `tenant_id`, `app_id`, `node_type`, `model_provider`, `model_name`, `status` | +| `message` | `tenant_id`, `app_id`, `model_provider`, `model_name`, `status`, `invoke_from` | +| `tool` | `tenant_id`, `app_id`, `tool_name` | +| `moderation` | `tenant_id`, `app_id` | +| `suggested_question` | `tenant_id`, `app_id`, `model_provider`, `model_name` | +| `dataset_retrieval` | `tenant_id`, `app_id` | +| `generate_name` | `tenant_id`, `app_id` | +| `prompt_generation` | `tenant_id`, `app_id`, `operation_type`, `model_provider`, `model_name`, `status` | + +### Error Counters + +| Metric | Unit | Description | +|--------|------|-------------| +| `dify.errors.total` | `{error}` | Total failed operations | + +**Labels by type:** + +| `type` | Additional Labels | +|--------|-------------------| +| `workflow` | `tenant_id`, `app_id` | +| `node` | `tenant_id`, `app_id`, `node_type`, `model_provider`, `model_name` | +| `draft_node` | `tenant_id`, `app_id`, `node_type`, `model_provider`, `model_name` | +| `message` | `tenant_id`, `app_id`, `model_provider`, `model_name` | +| `tool` | `tenant_id`, `app_id`, `tool_name` | +| `prompt_generation` | `tenant_id`, `app_id`, `operation_type`, `model_provider`, `model_name` | + +### Other 
Counters + +| Metric | Unit | Labels | +|--------|------|--------| +| `dify.feedback.total` | `{feedback}` | `tenant_id`, `app_id`, `rating` | +| `dify.dataset.retrievals.total` | `{retrieval}` | `tenant_id`, `app_id`, `dataset_id`, `embedding_model_provider`, `embedding_model`, `rerank_model_provider`, `rerank_model` | +| `dify.app.created.total` | `{app}` | `tenant_id`, `app_id`, `mode` | +| `dify.app.updated.total` | `{app}` | `tenant_id`, `app_id` | +| `dify.app.deleted.total` | `{app}` | `tenant_id`, `app_id` | + +## Histograms + +| Metric | Unit | Labels | +|--------|------|--------| +| `dify.workflow.duration` | `s` | `tenant_id`, `app_id`, `status` | +| `dify.node.duration` | `s` | `tenant_id`, `app_id`, `node_type`, `model_provider`, `model_name`, `plugin_name` | +| `dify.message.duration` | `s` | `tenant_id`, `app_id`, `model_provider`, `model_name` | +| `dify.message.time_to_first_token` | `s` | `tenant_id`, `app_id`, `model_provider`, `model_name` | +| `dify.tool.duration` | `s` | `tenant_id`, `app_id`, `tool_name` | +| `dify.prompt_generation.duration` | `s` | `tenant_id`, `app_id`, `operation_type`, `model_provider`, `model_name` | + +## Structured Logs + +### Span Companion Logs + +Logs that accompany spans. 
Signal type: `span_detail` + +#### `dify.workflow.run` Companion Log + +**Common attributes:** All span attributes (see Traces section) plus: + +| Additional Attribute | Type | Always Present | Description | +|---------------------|------|----------------|-------------| +| `dify.app.name` | string | No | Application display name | +| `dify.workspace.name` | string | No | Workspace display name | +| `dify.workflow.version` | string | Yes | Workflow definition version | +| `dify.workflow.inputs` | string/JSON | Yes | Input parameters (content-gated) | +| `dify.workflow.outputs` | string/JSON | Yes | Output results (content-gated) | +| `dify.workflow.query` | string | No | User query text (content-gated) | + +**Event attributes:** + +- `dify.event.name`: `"dify.workflow.run"` +- `dify.event.signal`: `"span_detail"` +- `trace_id`, `span_id`, `tenant_id`, `user_id` + +#### `dify.node.execution` and `dify.node.execution.draft` Companion Logs + +**Common attributes:** All span attributes (see Traces section) plus: + +| Additional Attribute | Type | Always Present | Description | +|---------------------|------|----------------|-------------| +| `dify.app.name` | string | No | Application display name | +| `dify.workspace.name` | string | No | Workspace display name | +| `dify.invoke_from` | string | No | Invocation source | +| `gen_ai.tool.name` | string | No | Tool name (tool nodes only) | +| `dify.node.total_price` | float | No | Cost (LLM nodes only) | +| `dify.node.currency` | string | No | Currency code (LLM nodes only) | +| `dify.node.iteration_index` | int | No | Iteration index (iteration nodes) | +| `dify.node.loop_index` | int | No | Loop index (loop nodes) | +| `dify.plugin.name` | string | No | Plugin name (tool/knowledge nodes) | +| `dify.credential.name` | string | No | Credential name (plugin nodes) | +| `dify.credential.id` | string | No | Credential ID (plugin nodes) | +| `dify.dataset.ids` | JSON array | No | Dataset IDs (knowledge nodes) | +| 
`dify.dataset.names` | JSON array | No | Dataset names (knowledge nodes) | +| `dify.node.inputs` | string/JSON | Yes | Node inputs (content-gated) | +| `dify.node.outputs` | string/JSON | Yes | Node outputs (content-gated) | +| `dify.node.process_data` | string/JSON | No | Processing data (content-gated) | + +**Event attributes:** + +- `dify.event.name`: `"dify.node.execution"` or `"dify.node.execution.draft"` +- `dify.event.signal`: `"span_detail"` +- `trace_id`, `span_id`, `tenant_id`, `user_id` + +### Standalone Logs + +Logs without structural spans. Signal type: `metric_only` + +#### `dify.message.run` + +| Attribute | Type | Description | +|-----------|------|-------------| +| `dify.event.name` | string | `"dify.message.run"` | +| `dify.event.signal` | string | `"metric_only"` | +| `trace_id` | string | OTEL trace ID (32-char hex) | +| `span_id` | string | OTEL span ID (16-char hex) | +| `tenant_id` | string | Tenant identifier | +| `user_id` | string | User identifier (optional) | +| `dify.app_id` | string | Application identifier | +| `dify.message.id` | string | Message identifier | +| `dify.conversation.id` | string | Conversation ID (optional) | +| `dify.workflow.run_id` | string | Workflow run ID (optional) | +| `dify.invoke_from` | string | `service-api`, `web-app`, `debugger`, `explore` | +| `gen_ai.provider.name` | string | LLM provider | +| `gen_ai.request.model` | string | LLM model | +| `gen_ai.usage.input_tokens` | int | Input tokens | +| `gen_ai.usage.output_tokens` | int | Output tokens | +| `gen_ai.usage.total_tokens` | int | Total tokens | +| `dify.message.status` | string | `succeeded`, `failed` | +| `dify.message.error` | string | Error message (if failed) | +| `dify.message.duration` | float | Duration (seconds) | +| `dify.message.time_to_first_token` | float | TTFT (seconds) | +| `dify.message.inputs` | string/JSON | Inputs (content-gated) | +| `dify.message.outputs` | string/JSON | Outputs (content-gated) | + +#### `dify.tool.execution` + 
+| Attribute | Type | Description | +|-----------|------|-------------| +| `dify.event.name` | string | `"dify.tool.execution"` | +| `dify.event.signal` | string | `"metric_only"` | +| `trace_id` | string | OTEL trace ID | +| `span_id` | string | OTEL span ID | +| `tenant_id` | string | Tenant identifier | +| `dify.app_id` | string | Application identifier | +| `dify.message.id` | string | Message identifier | +| `dify.tool.name` | string | Tool name | +| `dify.tool.duration` | float | Duration (seconds) | +| `dify.tool.status` | string | `succeeded`, `failed` | +| `dify.tool.error` | string | Error message (if failed) | +| `dify.tool.inputs` | string/JSON | Inputs (content-gated) | +| `dify.tool.outputs` | string/JSON | Outputs (content-gated) | +| `dify.tool.parameters` | string/JSON | Parameters (content-gated) | +| `dify.tool.config` | string/JSON | Configuration (content-gated) | + +#### `dify.moderation.check` + +| Attribute | Type | Description | +|-----------|------|-------------| +| `dify.event.name` | string | `"dify.moderation.check"` | +| `dify.event.signal` | string | `"metric_only"` | +| `trace_id` | string | OTEL trace ID | +| `span_id` | string | OTEL span ID | +| `tenant_id` | string | Tenant identifier | +| `dify.app_id` | string | Application identifier | +| `dify.message.id` | string | Message identifier | +| `dify.moderation.type` | string | `input`, `output` | +| `dify.moderation.action` | string | `pass`, `block`, `flag` | +| `dify.moderation.flagged` | boolean | Whether flagged | +| `dify.moderation.categories` | JSON array | Flagged categories | +| `dify.moderation.query` | string | Content (content-gated) | + +#### `dify.suggested_question.generation` + +| Attribute | Type | Description | +|-----------|------|-------------| +| `dify.event.name` | string | `"dify.suggested_question.generation"` | +| `dify.event.signal` | string | `"metric_only"` | +| `trace_id` | string | OTEL trace ID | +| `span_id` | string | OTEL span ID | +| `tenant_id` 
| string | Tenant identifier | +| `dify.app_id` | string | Application identifier | +| `dify.message.id` | string | Message identifier | +| `dify.suggested_question.count` | int | Number of questions | +| `dify.suggested_question.duration` | float | Duration (seconds) | +| `dify.suggested_question.status` | string | `succeeded`, `failed` | +| `dify.suggested_question.error` | string | Error message (if failed) | +| `dify.suggested_question.questions` | JSON array | Questions (content-gated) | + +#### `dify.dataset.retrieval` + +| Attribute | Type | Description | +|-----------|------|-------------| +| `dify.event.name` | string | `"dify.dataset.retrieval"` | +| `dify.event.signal` | string | `"metric_only"` | +| `trace_id` | string | OTEL trace ID | +| `span_id` | string | OTEL span ID | +| `tenant_id` | string | Tenant identifier | +| `dify.app_id` | string | Application identifier | +| `dify.message.id` | string | Message identifier | +| `dify.dataset.id` | string | Dataset identifier | +| `dify.dataset.name` | string | Dataset name | +| `dify.dataset.embedding_providers` | JSON array | Embedding model providers (one per dataset) | +| `dify.dataset.embedding_models` | JSON array | Embedding models (one per dataset) | +| `dify.retrieval.rerank_provider` | string | Rerank model provider | +| `dify.retrieval.rerank_model` | string | Rerank model name | +| `dify.retrieval.query` | string | Search query (content-gated) | +| `dify.retrieval.document_count` | int | Documents retrieved | +| `dify.retrieval.duration` | float | Duration (seconds) | +| `dify.retrieval.status` | string | `succeeded`, `failed` | +| `dify.retrieval.error` | string | Error message (if failed) | +| `dify.dataset.documents` | JSON array | Documents (content-gated) | + +#### `dify.generate_name.execution` + +| Attribute | Type | Description | +|-----------|------|-------------| +| `dify.event.name` | string | `"dify.generate_name.execution"` | +| `dify.event.signal` | string | `"metric_only"` | +| 
`trace_id` | string | OTEL trace ID | +| `span_id` | string | OTEL span ID | +| `tenant_id` | string | Tenant identifier | +| `dify.app_id` | string | Application identifier | +| `dify.conversation.id` | string | Conversation identifier | +| `dify.generate_name.duration` | float | Duration (seconds) | +| `dify.generate_name.status` | string | `succeeded`, `failed` | +| `dify.generate_name.error` | string | Error message (if failed) | +| `dify.generate_name.inputs` | string/JSON | Inputs (content-gated) | +| `dify.generate_name.outputs` | string | Generated name (content-gated) | + +#### `dify.prompt_generation.execution` + +| Attribute | Type | Description | +|-----------|------|-------------| +| `dify.event.name` | string | `"dify.prompt_generation.execution"` | +| `dify.event.signal` | string | `"metric_only"` | +| `trace_id` | string | OTEL trace ID | +| `span_id` | string | OTEL span ID | +| `tenant_id` | string | Tenant identifier | +| `dify.app_id` | string | Application identifier | +| `dify.prompt_generation.operation_type` | string | Operation type (see appendix) | +| `gen_ai.provider.name` | string | LLM provider | +| `gen_ai.request.model` | string | LLM model | +| `gen_ai.usage.input_tokens` | int | Input tokens | +| `gen_ai.usage.output_tokens` | int | Output tokens | +| `gen_ai.usage.total_tokens` | int | Total tokens | +| `dify.prompt_generation.duration` | float | Duration (seconds) | +| `dify.prompt_generation.status` | string | `succeeded`, `failed` | +| `dify.prompt_generation.error` | string | Error message (if failed) | +| `dify.prompt_generation.instruction` | string | Instruction (content-gated) | +| `dify.prompt_generation.output` | string/JSON | Output (content-gated) | + +#### `dify.app.created` + +| Attribute | Type | Description | +|-----------|------|-------------| +| `dify.event.name` | string | `"dify.app.created"` | +| `dify.event.signal` | string | `"metric_only"` | +| `tenant_id` | string | Tenant identifier | +| `dify.app_id` | 
string | Application identifier | +| `dify.app.mode` | string | `chat`, `completion`, `agent-chat`, `workflow` | +| `dify.app.created_at` | string | Timestamp (ISO 8601) | + +#### `dify.app.updated` + +| Attribute | Type | Description | +|-----------|------|-------------| +| `dify.event.name` | string | `"dify.app.updated"` | +| `dify.event.signal` | string | `"metric_only"` | +| `tenant_id` | string | Tenant identifier | +| `dify.app_id` | string | Application identifier | +| `dify.app.updated_at` | string | Timestamp (ISO 8601) | + +#### `dify.app.deleted` + +| Attribute | Type | Description | +|-----------|------|-------------| +| `dify.event.name` | string | `"dify.app.deleted"` | +| `dify.event.signal` | string | `"metric_only"` | +| `tenant_id` | string | Tenant identifier | +| `dify.app_id` | string | Application identifier | +| `dify.app.deleted_at` | string | Timestamp (ISO 8601) | + +#### `dify.feedback.created` + +| Attribute | Type | Description | +|-----------|------|-------------| +| `dify.event.name` | string | `"dify.feedback.created"` | +| `dify.event.signal` | string | `"metric_only"` | +| `trace_id` | string | OTEL trace ID | +| `span_id` | string | OTEL span ID | +| `tenant_id` | string | Tenant identifier | +| `dify.app_id` | string | Application identifier | +| `dify.message.id` | string | Message identifier | +| `dify.feedback.rating` | string | `like`, `dislike`, `null` | +| `dify.feedback.content` | string | Feedback text (content-gated) | +| `dify.feedback.created_at` | string | Timestamp (ISO 8601) | + +#### `dify.telemetry.rehydration_failed` + +Diagnostic event for telemetry system health monitoring. 
+ +| Attribute | Type | Description | +|-----------|------|-------------| +| `dify.event.name` | string | `"dify.telemetry.rehydration_failed"` | +| `dify.event.signal` | string | `"metric_only"` | +| `tenant_id` | string | Tenant identifier | +| `dify.telemetry.error` | string | Error message | +| `dify.telemetry.payload_type` | string | Payload type (see appendix) | +| `dify.telemetry.correlation_id` | string | Correlation ID | + +## Content-Gated Attributes + +When `ENTERPRISE_INCLUDE_CONTENT=false`, these attributes are replaced with reference strings (`ref:{id_type}={uuid}`). + +| Attribute | Signal | +|-----------|--------| +| `dify.workflow.inputs` | `dify.workflow.run` | +| `dify.workflow.outputs` | `dify.workflow.run` | +| `dify.workflow.query` | `dify.workflow.run` | +| `dify.node.inputs` | `dify.node.execution` | +| `dify.node.outputs` | `dify.node.execution` | +| `dify.node.process_data` | `dify.node.execution` | +| `dify.message.inputs` | `dify.message.run` | +| `dify.message.outputs` | `dify.message.run` | +| `dify.tool.inputs` | `dify.tool.execution` | +| `dify.tool.outputs` | `dify.tool.execution` | +| `dify.tool.parameters` | `dify.tool.execution` | +| `dify.tool.config` | `dify.tool.execution` | +| `dify.moderation.query` | `dify.moderation.check` | +| `dify.suggested_question.questions` | `dify.suggested_question.generation` | +| `dify.retrieval.query` | `dify.dataset.retrieval` | +| `dify.dataset.documents` | `dify.dataset.retrieval` | +| `dify.generate_name.inputs` | `dify.generate_name.execution` | +| `dify.generate_name.outputs` | `dify.generate_name.execution` | +| `dify.prompt_generation.instruction` | `dify.prompt_generation.execution` | +| `dify.prompt_generation.output` | `dify.prompt_generation.execution` | +| `dify.feedback.content` | `dify.feedback.created` | + +## Appendix + +### Operation Types + +- `workflow`, `node_execution`, `message`, `rule_generate`, `code_generate`, `structured_output`, `instruction_modify` + +### Node Types 
+ +- `start`, `end`, `answer`, `llm`, `knowledge-retrieval`, `knowledge-index`, `if-else`, `code`, `template-transform`, `question-classifier`, `http-request`, `tool`, `datasource`, `variable-aggregator`, `loop`, `iteration`, `parameter-extractor`, `assigner`, `document-extractor`, `list-operator`, `agent`, `trigger-webhook`, `trigger-schedule`, `trigger-plugin`, `human-input` + +### Workflow Statuses + +- `running`, `succeeded`, `failed`, `stopped`, `partial-succeeded`, `paused` + +### Payload Types + +- `workflow`, `node`, `message`, `tool`, `moderation`, `suggested_question`, `dataset_retrieval`, `generate_name`, `prompt_generation`, `app`, `feedback` + +### Null Value Behavior + +**Spans:** Attributes with `null` values are omitted. + +**Logs:** Attributes with `null` values appear as `null` in JSON. + +**Content-Gated:** Replaced with reference strings, not set to `null`. diff --git a/api/enterprise/telemetry/README.md b/api/enterprise/telemetry/README.md new file mode 100644 index 0000000000..e43c0b1ea2 --- /dev/null +++ b/api/enterprise/telemetry/README.md @@ -0,0 +1,121 @@ +# Dify Enterprise Telemetry + +This document provides an overview of the Dify Enterprise OpenTelemetry (OTEL) exporter and how to configure it for integration with observability stacks like Prometheus, Grafana, Jaeger, or Honeycomb. + +## Overview + +Dify Enterprise uses a "slim span + rich companion log" architecture to provide high-fidelity observability without overwhelming trace storage. + +- **Traces (Spans)**: Capture the structure, identity, and timing of high-level operations (Workflows and Nodes). +- **Structured Logs**: Provide deep context (inputs, outputs, metadata) for every event, correlated to spans via `trace_id` and `span_id`. +- **Metrics**: Provide 100% accurate counters and histograms for usage, performance, and error tracking. 
+ +### Signal Architecture + +```mermaid +graph TD + A[Workflow Run] -->|Span| B(dify.workflow.run) + A -->|Log| C(dify.workflow.run detail) + B ---|trace_id| C + + D[Node Execution] -->|Span| E(dify.node.execution) + D -->|Log| F(dify.node.execution detail) + E ---|span_id| F + + G[Message/Tool/etc] -->|Log| H(dify.* event) + G -->|Metric| I(dify.* counter/histogram) +``` + +## Configuration + +The Enterprise OTEL exporter is configured via environment variables. + +| Variable | Description | Default | +|----------|-------------|---------| +| `ENTERPRISE_ENABLED` | Master switch for all enterprise features. | `false` | +| `ENTERPRISE_TELEMETRY_ENABLED` | Master switch for enterprise telemetry. | `false` | +| `ENTERPRISE_OTLP_ENDPOINT` | OTLP collector endpoint (e.g., `http://otel-collector:4318`). | - | +| `ENTERPRISE_OTLP_HEADERS` | Custom headers for OTLP requests (e.g., `x-scope-orgid=tenant1`). | - | +| `ENTERPRISE_OTLP_PROTOCOL` | OTLP transport protocol (`http` or `grpc`). | `http` | +| `ENTERPRISE_OTLP_API_KEY` | Bearer token for authentication. | - | +| `ENTERPRISE_INCLUDE_CONTENT` | Whether to include sensitive content (inputs/outputs) in logs. | `false` | +| `ENTERPRISE_SERVICE_NAME` | Service name reported to OTEL. | `dify` | +| `ENTERPRISE_OTEL_SAMPLING_RATE` | Sampling rate for traces (0.0 to 1.0). Metrics are always 100%. | `1.0` | + +## Correlation Model + +Dify uses deterministic ID generation to ensure signals are correlated across different services and asynchronous tasks. + +### ID Generation Rules + +- `trace_id`: Derived from the correlation ID (workflow_run_id or node_execution_id for drafts) using `int(UUID(correlation_id))` +- `span_id`: Derived from the source ID using the lower 64 bits of `UUID(source_id)` + +### Scenario A: Simple Workflow + +A single workflow run with multiple nodes. All spans and logs share the same `trace_id` (derived from `workflow_run_id`). 
+ +``` +trace_id = UUID(workflow_run_id) +├── [root span] dify.workflow.run (span_id = hash(workflow_run_id)) +│ ├── [child] dify.node.execution - "Start" (span_id = hash(node_exec_id_1)) +│ ├── [child] dify.node.execution - "LLM" (span_id = hash(node_exec_id_2)) +│ └── [child] dify.node.execution - "End" (span_id = hash(node_exec_id_3)) +``` + +### Scenario B: Nested Sub-Workflow + +A workflow calling another workflow via a Tool or Sub-workflow node. The child workflow's spans are linked to the parent via `parent_span_id`. Both workflows share the same trace_id. + +``` +trace_id = UUID(outer_workflow_run_id) ← shared across both workflows +├── [root] dify.workflow.run (outer) (span_id = hash(outer_workflow_run_id)) +│ ├── dify.node.execution - "Start Node" +│ ├── dify.node.execution - "Tool Node" (triggers sub-workflow) +│ │ └── [child] dify.workflow.run (inner) (span_id = hash(inner_workflow_run_id)) +│ │ ├── dify.node.execution - "Inner Start" +│ │ └── dify.node.execution - "Inner End" +│ └── dify.node.execution - "End Node" +``` + +**Key attributes for nested workflows:** + +- Inner workflow's `dify.parent.trace_id` = outer `workflow_run_id` +- Inner workflow's `dify.parent.node.execution_id` = tool node's `execution_id` +- Inner workflow's `dify.parent.workflow.run_id` = outer `workflow_run_id` +- Inner workflow's `dify.parent.app.id` = outer `app_id` + +### Scenario C: Draft Node Execution + +A single node run in isolation (debugger/preview mode). It creates its own trace where the node span is the root. + +``` +trace_id = UUID(node_execution_id) ← own trace, NOT part of any workflow +└── dify.node.execution.draft (span_id = hash(node_execution_id)) +``` + +**Key difference:** Draft executions use `node_execution_id` as the correlation_id, so they are NOT children of any workflow trace. 
+ +## Content Gating + +When `ENTERPRISE_INCLUDE_CONTENT` is set to `false`, sensitive content attributes (inputs, outputs, queries) are replaced with reference strings (e.g., `ref:workflow_run_id=...`) to prevent data leakage to the OTEL collector. + +**Reference String Format:** + +``` +ref:{id_type}={uuid} +``` + +**Examples:** + +``` +ref:workflow_run_id=550e8400-e29b-41d4-a716-446655440000 +ref:node_execution_id=660e8400-e29b-41d4-a716-446655440001 +ref:message_id=770e8400-e29b-41d4-a716-446655440002 +``` + +To retrieve actual content when gating is enabled, query the Dify database using the provided UUID. + +## Reference + +For a complete list of telemetry signals, attributes, and data structures, see [DATA_DICTIONARY.md](./DATA_DICTIONARY.md). diff --git a/api/dify_graph/graph_engine/entities/__init__.py b/api/enterprise/telemetry/__init__.py similarity index 100% rename from api/dify_graph/graph_engine/entities/__init__.py rename to api/enterprise/telemetry/__init__.py diff --git a/api/enterprise/telemetry/contracts.py b/api/enterprise/telemetry/contracts.py new file mode 100644 index 0000000000..91398cb8cb --- /dev/null +++ b/api/enterprise/telemetry/contracts.py @@ -0,0 +1,73 @@ +"""Telemetry gateway contracts and data structures. + +This module defines the envelope format for telemetry events and the routing +configuration that determines how each event type is processed. 
+""" + +from __future__ import annotations + +from enum import StrEnum +from typing import Any + +from pydantic import BaseModel, ConfigDict + + +class TelemetryCase(StrEnum): + """Enumeration of all known telemetry event cases.""" + + WORKFLOW_RUN = "workflow_run" + NODE_EXECUTION = "node_execution" + DRAFT_NODE_EXECUTION = "draft_node_execution" + MESSAGE_RUN = "message_run" + TOOL_EXECUTION = "tool_execution" + MODERATION_CHECK = "moderation_check" + SUGGESTED_QUESTION = "suggested_question" + DATASET_RETRIEVAL = "dataset_retrieval" + GENERATE_NAME = "generate_name" + PROMPT_GENERATION = "prompt_generation" + APP_CREATED = "app_created" + APP_UPDATED = "app_updated" + APP_DELETED = "app_deleted" + FEEDBACK_CREATED = "feedback_created" + + +class SignalType(StrEnum): + """Signal routing type for telemetry cases.""" + + TRACE = "trace" + METRIC_LOG = "metric_log" + + +class CaseRoute(BaseModel): + """Routing configuration for a telemetry case. + + Attributes: + signal_type: The type of signal (trace or metric_log). + ce_eligible: Whether this case is eligible for community edition tracing. + """ + + signal_type: SignalType + ce_eligible: bool + + +class TelemetryEnvelope(BaseModel): + """Envelope for telemetry events. + + Attributes: + case: The telemetry case type. + tenant_id: The tenant identifier. + event_id: Unique event identifier for deduplication. + payload: The main event payload (inline for small payloads, + empty when offloaded to storage via ``payload_ref``). + metadata: Optional metadata dictionary. When the gateway + offloads a large payload to object storage, this contains + ``{"payload_ref": ""}``. 
+ """ + + model_config = ConfigDict(extra="forbid", use_enum_values=False) + + case: TelemetryCase + tenant_id: str + event_id: str + payload: dict[str, Any] + metadata: dict[str, Any] | None = None diff --git a/api/enterprise/telemetry/draft_trace.py b/api/enterprise/telemetry/draft_trace.py new file mode 100644 index 0000000000..5a8d0ee6f4 --- /dev/null +++ b/api/enterprise/telemetry/draft_trace.py @@ -0,0 +1,90 @@ +from __future__ import annotations + +from collections.abc import Mapping +from typing import Any + +from graphon.enums import WorkflowNodeExecutionMetadataKey + +from core.telemetry import TelemetryContext, TelemetryEvent, TraceTaskName +from core.telemetry import emit as telemetry_emit +from models.workflow import WorkflowNodeExecutionModel + + +def enqueue_draft_node_execution_trace( + *, + execution: WorkflowNodeExecutionModel, + outputs: Mapping[str, Any] | None, + workflow_execution_id: str | None, + user_id: str, +) -> None: + node_data = _build_node_execution_data( + execution=execution, + outputs=outputs, + workflow_execution_id=workflow_execution_id, + ) + telemetry_emit( + TelemetryEvent( + name=TraceTaskName.DRAFT_NODE_EXECUTION_TRACE, + context=TelemetryContext( + tenant_id=execution.tenant_id, + user_id=user_id, + app_id=execution.app_id, + ), + payload={"node_execution_data": node_data}, + ) + ) + + +def _build_node_execution_data( + *, + execution: WorkflowNodeExecutionModel, + outputs: Mapping[str, Any] | None, + workflow_execution_id: str | None, +) -> dict[str, Any]: + metadata = execution.execution_metadata_dict + node_outputs = outputs if outputs is not None else execution.outputs_dict + execution_id = workflow_execution_id or execution.workflow_run_id or execution.id + process_data = execution.process_data_dict or {} + + # Extract token breakdown from outputs.usage (set by LLM node) + usage: Mapping[str, Any] = {} + if isinstance(node_outputs, Mapping): + raw_usage = node_outputs.get("usage") + if isinstance(raw_usage, Mapping): 
+ usage = raw_usage + + return { + "workflow_id": execution.workflow_id, + "workflow_execution_id": execution_id, + "tenant_id": execution.tenant_id, + "app_id": execution.app_id, + "node_execution_id": execution.id, + "node_id": execution.node_id, + "node_type": execution.node_type, + "title": execution.title, + "status": execution.status, + "error": execution.error, + "elapsed_time": execution.elapsed_time, + "index": execution.index, + "predecessor_node_id": execution.predecessor_node_id, + "created_at": execution.created_at, + "finished_at": execution.finished_at, + "total_tokens": metadata.get(WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS, 0), + "total_price": metadata.get(WorkflowNodeExecutionMetadataKey.TOTAL_PRICE, 0.0), + "currency": metadata.get(WorkflowNodeExecutionMetadataKey.CURRENCY), + "model_provider": process_data.get("model_provider"), + "model_name": process_data.get("model_name"), + "prompt_tokens": usage.get("prompt_tokens"), + "completion_tokens": usage.get("completion_tokens"), + "tool_name": (metadata.get(WorkflowNodeExecutionMetadataKey.TOOL_INFO) or {}).get("tool_name") + if isinstance(metadata.get(WorkflowNodeExecutionMetadataKey.TOOL_INFO), dict) + else None, + "iteration_id": metadata.get(WorkflowNodeExecutionMetadataKey.ITERATION_ID), + "iteration_index": metadata.get(WorkflowNodeExecutionMetadataKey.ITERATION_INDEX), + "loop_id": metadata.get(WorkflowNodeExecutionMetadataKey.LOOP_ID), + "loop_index": metadata.get(WorkflowNodeExecutionMetadataKey.LOOP_INDEX), + "parallel_id": metadata.get(WorkflowNodeExecutionMetadataKey.PARALLEL_ID), + "node_inputs": execution.inputs_dict, + "node_outputs": node_outputs, + "process_data": execution.process_data_dict, + } diff --git a/api/enterprise/telemetry/enterprise_trace.py b/api/enterprise/telemetry/enterprise_trace.py new file mode 100644 index 0000000000..fc17d9d93e --- /dev/null +++ b/api/enterprise/telemetry/enterprise_trace.py @@ -0,0 +1,966 @@ +"""Enterprise trace handler — duck-typed, NOT 
a BaseTraceInstance subclass. + +Invoked directly in the Celery task, not through OpsTraceManager dispatch. +Only requires a matching ``trace(trace_info)`` method signature. + +Signal strategy: +- **Traces (spans)**: workflow run, node execution, draft node execution only. +- **Metrics + structured logs**: all other event types. + +Token metric labels (unified structure): +All token metrics (dify.tokens.input, dify.tokens.output, dify.tokens.total) use the +same label set for consistent filtering and aggregation: +- tenant_id: Tenant identifier +- app_id: Application identifier +- operation_type: Source of token usage (workflow | node_execution | message | rule_generate | etc.) +- model_provider: LLM provider name (empty string if not applicable) +- model_name: LLM model name (empty string if not applicable) +- node_type: Workflow node type (empty string if not node_execution) + +This unified structure allows filtering by operation_type to separate: +- Workflow-level aggregates (operation_type=workflow) +- Individual node executions (operation_type=node_execution) +- Direct message calls (operation_type=message) +- Prompt generation operations (operation_type=rule_generate, code_generate, etc.) + +Without this, tokens are double-counted when querying totals (workflow totals include +node totals, since workflow.total_tokens is the sum of all node tokens). 
+""" + +from __future__ import annotations + +import json +import logging +from typing import Any, cast + +from opentelemetry.util.types import AttributeValue + +from core.ops.entities.trace_entity import ( + BaseTraceInfo, + DatasetRetrievalTraceInfo, + DraftNodeExecutionTrace, + GenerateNameTraceInfo, + MessageTraceInfo, + ModerationTraceInfo, + OperationType, + PromptGenerationTraceInfo, + SuggestedQuestionTraceInfo, + ToolTraceInfo, + WorkflowNodeTraceInfo, + WorkflowTraceInfo, +) +from enterprise.telemetry.entities import ( + EnterpriseTelemetryCounter, + EnterpriseTelemetryEvent, + EnterpriseTelemetryHistogram, + EnterpriseTelemetrySpan, + TokenMetricLabels, +) +from enterprise.telemetry.telemetry_log import emit_metric_only_event, emit_telemetry_log + +logger = logging.getLogger(__name__) + + +class EnterpriseOtelTrace: + """Duck-typed enterprise trace handler. + + ``*_trace`` methods emit spans (workflow/node only) or structured logs + (all other events), plus metrics at 100 % accuracy. 
+ """ + + def __init__(self) -> None: + from extensions.ext_enterprise_telemetry import get_enterprise_exporter + + exporter = get_enterprise_exporter() + if exporter is None: + raise RuntimeError("EnterpriseOtelTrace instantiated but exporter is not initialized") + self._exporter = exporter + + def trace(self, trace_info: BaseTraceInfo) -> None: + if isinstance(trace_info, WorkflowTraceInfo): + self._workflow_trace(trace_info) + elif isinstance(trace_info, MessageTraceInfo): + self._message_trace(trace_info) + elif isinstance(trace_info, ToolTraceInfo): + self._tool_trace(trace_info) + elif isinstance(trace_info, DraftNodeExecutionTrace): + self._draft_node_execution_trace(trace_info) + elif isinstance(trace_info, WorkflowNodeTraceInfo): + self._node_execution_trace(trace_info) + elif isinstance(trace_info, ModerationTraceInfo): + self._moderation_trace(trace_info) + elif isinstance(trace_info, SuggestedQuestionTraceInfo): + self._suggested_question_trace(trace_info) + elif isinstance(trace_info, DatasetRetrievalTraceInfo): + self._dataset_retrieval_trace(trace_info) + elif isinstance(trace_info, GenerateNameTraceInfo): + self._generate_name_trace(trace_info) + elif isinstance(trace_info, PromptGenerationTraceInfo): + self._prompt_generation_trace(trace_info) + else: + raise AssertionError("this statment should be unreachable") + + def _common_attrs(self, trace_info: BaseTraceInfo) -> dict[str, Any]: + metadata = self._metadata(trace_info) + tenant_id, app_id, user_id = self._context_ids(trace_info, metadata) + return { + "dify.trace_id": trace_info.resolved_trace_id, + "dify.tenant_id": tenant_id, + "dify.app_id": app_id, + "dify.app.name": metadata.get("app_name"), + "dify.workspace.name": metadata.get("workspace_name"), + "gen_ai.user.id": user_id, + "dify.message.id": trace_info.message_id, + } + + def _metadata(self, trace_info: BaseTraceInfo) -> dict[str, Any]: + return trace_info.metadata + + def _context_ids( + self, + trace_info: BaseTraceInfo, + 
metadata: dict[str, Any], + ) -> tuple[str | None, str | None, str | None]: + tenant_id = getattr(trace_info, "tenant_id", None) or metadata.get("tenant_id") + app_id = getattr(trace_info, "app_id", None) or metadata.get("app_id") + user_id = getattr(trace_info, "user_id", None) or metadata.get("user_id") + return tenant_id, app_id, user_id + + def _labels(self, **values: AttributeValue) -> dict[str, AttributeValue]: + return dict(values) + + def _safe_payload_value(self, value: Any) -> str | dict[str, Any] | list[object] | None: + if isinstance(value, str): + return value + if isinstance(value, dict): + return cast(dict[str, Any], value) + if isinstance(value, list): + items: list[object] = [] + for item in cast(list[object], value): + items.append(item) + return items + return None + + def _content_or_ref(self, value: Any, ref: str) -> Any: + if self._exporter.include_content: + return self._maybe_json(value) + return ref + + def _maybe_json(self, value: Any) -> str | None: + if value is None: + return None + if isinstance(value, str): + return value + try: + return json.dumps(value, default=str) + except (TypeError, ValueError): + return str(value) + + # ------------------------------------------------------------------ + # SPAN-emitting handlers (workflow, node execution, draft node) + # ------------------------------------------------------------------ + + def _workflow_trace(self, info: WorkflowTraceInfo) -> None: + metadata = self._metadata(info) + tenant_id, app_id, user_id = self._context_ids(info, metadata) + # -- Span attrs: identity + structure + status + timing + gen_ai scalars -- + span_attrs: dict[str, Any] = { + "dify.trace_id": info.resolved_trace_id, + "dify.tenant_id": tenant_id, + "dify.app_id": app_id, + "dify.workflow.id": info.workflow_id, + "dify.workflow.run_id": info.workflow_run_id, + "dify.workflow.status": info.workflow_run_status, + "dify.workflow.error": info.error, + "dify.workflow.elapsed_time": info.workflow_run_elapsed_time, + 
"dify.invoke_from": metadata.get("triggered_from"), + "dify.conversation.id": info.conversation_id, + "dify.message.id": info.message_id, + "dify.invoked_by": info.invoked_by, + "gen_ai.usage.total_tokens": info.total_tokens, + "gen_ai.user.id": user_id, + } + + trace_correlation_override, parent_span_id_source = info.resolved_parent_context + + parent_ctx = metadata.get("parent_trace_context") + if isinstance(parent_ctx, dict): + parent_ctx_dict = cast(dict[str, Any], parent_ctx) + span_attrs["dify.parent.trace_id"] = parent_ctx_dict.get("trace_id") + span_attrs["dify.parent.node.execution_id"] = parent_ctx_dict.get("parent_node_execution_id") + span_attrs["dify.parent.workflow.run_id"] = parent_ctx_dict.get("parent_workflow_run_id") + span_attrs["dify.parent.app.id"] = parent_ctx_dict.get("parent_app_id") + + self._exporter.export_span( + EnterpriseTelemetrySpan.WORKFLOW_RUN, + span_attrs, + correlation_id=info.workflow_run_id, + span_id_source=info.workflow_run_id, + start_time=info.start_time, + end_time=info.end_time, + trace_correlation_override=trace_correlation_override, + parent_span_id_source=parent_span_id_source, + ) + + # -- Companion log: ALL attrs (span + detail) for full picture -- + log_attrs: dict[str, Any] = {**span_attrs} + log_attrs.update( + { + "dify.app.name": metadata.get("app_name"), + "dify.workspace.name": metadata.get("workspace_name"), + "gen_ai.user.id": user_id, + "gen_ai.usage.total_tokens": info.total_tokens, + "dify.workflow.version": info.workflow_run_version, + } + ) + + ref = f"ref:workflow_run_id={info.workflow_run_id}" + log_attrs["dify.workflow.inputs"] = self._content_or_ref(info.workflow_run_inputs, ref) + log_attrs["dify.workflow.outputs"] = self._content_or_ref(info.workflow_run_outputs, ref) + log_attrs["dify.workflow.query"] = self._content_or_ref(info.query, ref) + + emit_telemetry_log( + event_name=EnterpriseTelemetryEvent.WORKFLOW_RUN, + attributes=log_attrs, + signal="span_detail", + 
trace_id_source=info.workflow_run_id, + span_id_source=info.workflow_run_id, + tenant_id=tenant_id, + user_id=user_id, + ) + + # -- Metrics -- + labels = self._labels( + tenant_id=tenant_id or "", + app_id=app_id or "", + ) + token_labels = TokenMetricLabels( + tenant_id=tenant_id or "", + app_id=app_id or "", + operation_type=OperationType.WORKFLOW, + model_provider="", + model_name="", + node_type="", + ).to_dict() + self._exporter.increment_counter(EnterpriseTelemetryCounter.TOKENS, info.total_tokens, token_labels) + if info.prompt_tokens is not None and info.prompt_tokens > 0: + self._exporter.increment_counter(EnterpriseTelemetryCounter.INPUT_TOKENS, info.prompt_tokens, token_labels) + if info.completion_tokens is not None and info.completion_tokens > 0: + self._exporter.increment_counter( + EnterpriseTelemetryCounter.OUTPUT_TOKENS, info.completion_tokens, token_labels + ) + invoke_from = metadata.get("triggered_from", "") + self._exporter.increment_counter( + EnterpriseTelemetryCounter.REQUESTS, + 1, + self._labels( + **labels, + type="workflow", + status=info.workflow_run_status, + invoke_from=invoke_from, + ), + ) + # Prefer wall-clock timestamps over the elapsed_time field: elapsed_time defaults + # to 0 in the DB and can be stale if the Celery write races with the trace task. + # start_time = workflow_run.created_at, end_time = workflow_run.finished_at. 
+ if info.start_time and info.end_time: + workflow_duration = (info.end_time - info.start_time).total_seconds() + elif info.workflow_run_elapsed_time: + workflow_duration = float(info.workflow_run_elapsed_time) + else: + workflow_duration = 0.0 + self._exporter.record_histogram( + EnterpriseTelemetryHistogram.WORKFLOW_DURATION, + workflow_duration, + self._labels( + **labels, + status=info.workflow_run_status, + ), + ) + + if info.error: + self._exporter.increment_counter( + EnterpriseTelemetryCounter.ERRORS, + 1, + self._labels( + **labels, + type="workflow", + ), + ) + + def _node_execution_trace(self, info: WorkflowNodeTraceInfo) -> None: + self._emit_node_execution_trace(info, EnterpriseTelemetrySpan.NODE_EXECUTION, "node") + + def _draft_node_execution_trace(self, info: DraftNodeExecutionTrace) -> None: + self._emit_node_execution_trace( + info, + EnterpriseTelemetrySpan.DRAFT_NODE_EXECUTION, + "draft_node", + correlation_id_override=info.node_execution_id, + trace_correlation_override_param=info.workflow_run_id, + ) + + def _emit_node_execution_trace( + self, + info: WorkflowNodeTraceInfo, + span_name: EnterpriseTelemetrySpan, + request_type: str, + correlation_id_override: str | None = None, + trace_correlation_override_param: str | None = None, + ) -> None: + metadata = self._metadata(info) + tenant_id, app_id, user_id = self._context_ids(info, metadata) + # -- Span attrs: identity + structure + status + timing + gen_ai scalars -- + span_attrs: dict[str, Any] = { + "dify.trace_id": info.resolved_trace_id, + "dify.tenant_id": tenant_id, + "dify.app_id": app_id, + "dify.workflow.id": info.workflow_id, + "dify.workflow.run_id": info.workflow_run_id, + "dify.message.id": info.message_id, + "dify.conversation.id": metadata.get("conversation_id"), + "dify.node.execution_id": info.node_execution_id, + "dify.node.id": info.node_id, + "dify.node.type": info.node_type, + "dify.node.title": info.title, + "dify.node.status": info.status, + "dify.node.error": 
info.error, + "dify.node.elapsed_time": info.elapsed_time, + "dify.node.index": info.index, + "dify.node.predecessor_node_id": info.predecessor_node_id, + "dify.node.iteration_id": info.iteration_id, + "dify.node.loop_id": info.loop_id, + "dify.node.parallel_id": info.parallel_id, + "dify.node.invoked_by": info.invoked_by, + "gen_ai.usage.input_tokens": info.prompt_tokens, + "gen_ai.usage.output_tokens": info.completion_tokens, + "gen_ai.usage.total_tokens": info.total_tokens, + "gen_ai.request.model": info.model_name, + "gen_ai.provider.name": info.model_provider, + "gen_ai.user.id": user_id, + } + + resolved_override, _ = info.resolved_parent_context + trace_correlation_override = trace_correlation_override_param or resolved_override + + effective_correlation_id = correlation_id_override or info.workflow_run_id + self._exporter.export_span( + span_name, + span_attrs, + correlation_id=effective_correlation_id, + span_id_source=info.node_execution_id, + start_time=info.start_time, + end_time=info.end_time, + trace_correlation_override=trace_correlation_override, + ) + + # -- Companion log: ALL attrs (span + detail) -- + log_attrs: dict[str, Any] = {**span_attrs} + log_attrs.update( + { + "dify.app.name": metadata.get("app_name"), + "dify.workspace.name": metadata.get("workspace_name"), + "dify.invoke_from": metadata.get("invoke_from"), + "gen_ai.user.id": user_id, + "gen_ai.usage.total_tokens": info.total_tokens, + "dify.node.total_price": info.total_price, + "dify.node.currency": info.currency, + "gen_ai.provider.name": info.model_provider, + "gen_ai.request.model": info.model_name, + "gen_ai.tool.name": info.tool_name, + "dify.node.iteration_index": info.iteration_index, + "dify.node.loop_index": info.loop_index, + "dify.plugin.name": metadata.get("plugin_name"), + "dify.credential.name": metadata.get("credential_name"), + "dify.credential.id": metadata.get("credential_id"), + "dify.dataset.ids": self._maybe_json(metadata.get("dataset_ids")), + 
"dify.dataset.names": self._maybe_json(metadata.get("dataset_names")), + } + ) + + ref = f"ref:node_execution_id={info.node_execution_id}" + log_attrs["dify.node.inputs"] = self._content_or_ref(info.node_inputs, ref) + log_attrs["dify.node.outputs"] = self._content_or_ref(info.node_outputs, ref) + log_attrs["dify.node.process_data"] = self._content_or_ref(info.process_data, ref) + + emit_telemetry_log( + event_name=span_name.value, + attributes=log_attrs, + signal="span_detail", + trace_id_source=info.workflow_run_id, + span_id_source=info.node_execution_id, + tenant_id=tenant_id, + user_id=user_id, + ) + + # -- Metrics -- + labels = self._labels( + tenant_id=tenant_id or "", + app_id=app_id or "", + node_type=info.node_type, + model_provider=info.model_provider or "", + ) + if info.total_tokens: + token_labels = TokenMetricLabels( + tenant_id=tenant_id or "", + app_id=app_id or "", + operation_type=OperationType.NODE_EXECUTION, + model_provider=info.model_provider or "", + model_name=info.model_name or "", + node_type=info.node_type, + ).to_dict() + self._exporter.increment_counter(EnterpriseTelemetryCounter.TOKENS, info.total_tokens, token_labels) + if info.prompt_tokens is not None and info.prompt_tokens > 0: + self._exporter.increment_counter( + EnterpriseTelemetryCounter.INPUT_TOKENS, info.prompt_tokens, token_labels + ) + if info.completion_tokens is not None and info.completion_tokens > 0: + self._exporter.increment_counter( + EnterpriseTelemetryCounter.OUTPUT_TOKENS, info.completion_tokens, token_labels + ) + self._exporter.increment_counter( + EnterpriseTelemetryCounter.REQUESTS, + 1, + self._labels( + **labels, + type=request_type, + status=info.status, + model_name=info.model_name or "", + ), + ) + duration_labels = dict(labels) + duration_labels["model_name"] = info.model_name or "" + plugin_name = metadata.get("plugin_name") + if plugin_name and info.node_type in {"tool", "knowledge-retrieval"}: + duration_labels["plugin_name"] = plugin_name + 
self._exporter.record_histogram(EnterpriseTelemetryHistogram.NODE_DURATION, info.elapsed_time, duration_labels) + + if info.error: + self._exporter.increment_counter( + EnterpriseTelemetryCounter.ERRORS, + 1, + self._labels( + **labels, + type=request_type, + model_name=info.model_name or "", + ), + ) + + # ------------------------------------------------------------------ + # METRIC-ONLY handlers (structured log + counters/histograms) + # ------------------------------------------------------------------ + + def _message_trace(self, info: MessageTraceInfo) -> None: + metadata = self._metadata(info) + tenant_id, app_id, user_id = self._context_ids(info, metadata) + attrs = self._common_attrs(info) + attrs.update( + { + "dify.invoke_from": metadata.get("from_source"), + "dify.conversation.id": metadata.get("conversation_id"), + "dify.conversation.mode": info.conversation_mode, + "gen_ai.provider.name": metadata.get("ls_provider"), + "gen_ai.request.model": metadata.get("ls_model_name"), + "gen_ai.usage.input_tokens": info.message_tokens, + "gen_ai.usage.output_tokens": info.answer_tokens, + "gen_ai.usage.total_tokens": info.total_tokens, + "dify.message.status": metadata.get("status"), + "dify.message.error": info.error, + "dify.message.from_source": metadata.get("from_source"), + "dify.message.from_end_user_id": metadata.get("from_end_user_id"), + "dify.message.from_account_id": metadata.get("from_account_id"), + "dify.streaming": info.is_streaming_request, + "dify.message.time_to_first_token": info.gen_ai_server_time_to_first_token, + "dify.message.streaming_duration": info.llm_streaming_time_to_generate, + "dify.workflow.run_id": metadata.get("workflow_run_id"), + } + ) + + if info.start_time and info.end_time: + attrs["dify.message.duration"] = (info.end_time - info.start_time).total_seconds() + + node_execution_id = metadata.get("node_execution_id") + if node_execution_id: + attrs["dify.node.execution_id"] = node_execution_id + + ref = 
f"ref:message_id={info.message_id}" + inputs = self._safe_payload_value(info.inputs) + outputs = self._safe_payload_value(info.outputs) + attrs["dify.message.inputs"] = self._content_or_ref(inputs, ref) + attrs["dify.message.outputs"] = self._content_or_ref(outputs, ref) + + emit_metric_only_event( + event_name=EnterpriseTelemetryEvent.MESSAGE_RUN, + attributes=attrs, + trace_id_source=metadata.get("workflow_run_id") or (str(info.message_id) if info.message_id else None), + span_id_source=node_execution_id, + tenant_id=tenant_id, + user_id=user_id, + ) + + labels = self._labels( + tenant_id=tenant_id or "", + app_id=app_id or "", + model_provider=metadata.get("ls_provider") or "", + model_name=metadata.get("ls_model_name") or "", + ) + token_labels = TokenMetricLabels( + tenant_id=tenant_id or "", + app_id=app_id or "", + operation_type=OperationType.MESSAGE, + model_provider=metadata.get("ls_provider") or "", + model_name=metadata.get("ls_model_name") or "", + node_type="", + ).to_dict() + self._exporter.increment_counter(EnterpriseTelemetryCounter.TOKENS, info.total_tokens, token_labels) + if info.message_tokens > 0: + self._exporter.increment_counter(EnterpriseTelemetryCounter.INPUT_TOKENS, info.message_tokens, token_labels) + if info.answer_tokens > 0: + self._exporter.increment_counter(EnterpriseTelemetryCounter.OUTPUT_TOKENS, info.answer_tokens, token_labels) + invoke_from = metadata.get("from_source", "") + self._exporter.increment_counter( + EnterpriseTelemetryCounter.REQUESTS, + 1, + self._labels( + **labels, + type="message", + status=metadata.get("status", ""), + invoke_from=invoke_from, + ), + ) + + if info.start_time and info.end_time: + duration = (info.end_time - info.start_time).total_seconds() + self._exporter.record_histogram(EnterpriseTelemetryHistogram.MESSAGE_DURATION, duration, labels) + + if info.gen_ai_server_time_to_first_token is not None: + self._exporter.record_histogram( + EnterpriseTelemetryHistogram.MESSAGE_TTFT, 
info.gen_ai_server_time_to_first_token, labels + ) + + if info.error: + self._exporter.increment_counter( + EnterpriseTelemetryCounter.ERRORS, + 1, + self._labels( + **labels, + type="message", + ), + ) + + def _tool_trace(self, info: ToolTraceInfo) -> None: + metadata = self._metadata(info) + tenant_id, app_id, user_id = self._context_ids(info, metadata) + attrs = self._common_attrs(info) + attrs.update( + { + "dify.tool.name": info.tool_name, + "dify.tool.duration": float(info.time_cost), + "dify.tool.status": "failed" if info.error else "succeeded", + "dify.tool.error": info.error, + "dify.workflow.run_id": metadata.get("workflow_run_id"), + } + ) + node_execution_id = metadata.get("node_execution_id") + if node_execution_id: + attrs["dify.node.execution_id"] = node_execution_id + + ref = f"ref:message_id={info.message_id}" + attrs["dify.tool.inputs"] = self._content_or_ref(info.tool_inputs, ref) + attrs["dify.tool.outputs"] = self._content_or_ref(info.tool_outputs, ref) + attrs["dify.tool.parameters"] = self._content_or_ref(info.tool_parameters, ref) + attrs["dify.tool.config"] = self._content_or_ref(info.tool_config, ref) + + emit_metric_only_event( + event_name=EnterpriseTelemetryEvent.TOOL_EXECUTION, + attributes=attrs, + trace_id_source=info.resolved_trace_id, + span_id_source=node_execution_id, + tenant_id=tenant_id, + user_id=user_id, + ) + + labels = self._labels( + tenant_id=tenant_id or "", + app_id=app_id or "", + tool_name=info.tool_name, + ) + self._exporter.increment_counter( + EnterpriseTelemetryCounter.REQUESTS, + 1, + self._labels( + **labels, + type="tool", + ), + ) + self._exporter.record_histogram(EnterpriseTelemetryHistogram.TOOL_DURATION, float(info.time_cost), labels) + + if info.error: + self._exporter.increment_counter( + EnterpriseTelemetryCounter.ERRORS, + 1, + self._labels( + **labels, + type="tool", + ), + ) + + def _moderation_trace(self, info: ModerationTraceInfo) -> None: + metadata = self._metadata(info) + tenant_id, app_id, 
user_id = self._context_ids(info, metadata) + attrs = self._common_attrs(info) + attrs.update( + { + "dify.moderation.flagged": info.flagged, + "dify.moderation.action": info.action, + "dify.moderation.preset_response": info.preset_response, + "dify.moderation.type": metadata.get("moderation_type", "input"), + "dify.moderation.categories": self._maybe_json(metadata.get("moderation_categories", [])), + "dify.workflow.run_id": metadata.get("workflow_run_id"), + } + ) + node_execution_id = metadata.get("node_execution_id") + if node_execution_id: + attrs["dify.node.execution_id"] = node_execution_id + + attrs["dify.moderation.query"] = self._content_or_ref( + info.query, + f"ref:message_id={info.message_id}", + ) + + emit_metric_only_event( + event_name=EnterpriseTelemetryEvent.MODERATION_CHECK, + attributes=attrs, + trace_id_source=info.resolved_trace_id, + span_id_source=node_execution_id, + tenant_id=tenant_id, + user_id=user_id, + ) + + labels = self._labels( + tenant_id=tenant_id or "", + app_id=app_id or "", + ) + self._exporter.increment_counter( + EnterpriseTelemetryCounter.REQUESTS, + 1, + self._labels( + **labels, + type="moderation", + ), + ) + + def _suggested_question_trace(self, info: SuggestedQuestionTraceInfo) -> None: + metadata = self._metadata(info) + tenant_id, app_id, user_id = self._context_ids(info, metadata) + attrs = self._common_attrs(info) + duration: float | None = None + if info.start_time is not None and info.end_time is not None: + duration = (info.end_time - info.start_time).total_seconds() + error = info.error or (info.metadata.get("error") if info.metadata else None) + status = "failed" if error else (info.status or "succeeded") + attrs.update( + { + "gen_ai.usage.total_tokens": info.total_tokens, + "dify.suggested_question.status": status, + "dify.suggested_question.error": error, + "dify.suggested_question.duration": duration, + "gen_ai.provider.name": info.model_provider, + "gen_ai.request.model": info.model_id, + 
"dify.suggested_question.count": len(info.suggested_question), + "dify.workflow.run_id": metadata.get("workflow_run_id"), + } + ) + node_execution_id = metadata.get("node_execution_id") + if node_execution_id: + attrs["dify.node.execution_id"] = node_execution_id + + attrs["dify.suggested_question.questions"] = self._content_or_ref( + info.suggested_question, + f"ref:message_id={info.message_id}", + ) + + emit_metric_only_event( + event_name=EnterpriseTelemetryEvent.SUGGESTED_QUESTION_GENERATION, + attributes=attrs, + trace_id_source=info.resolved_trace_id, + span_id_source=node_execution_id, + tenant_id=tenant_id, + user_id=user_id, + ) + + labels = self._labels( + tenant_id=tenant_id or "", + app_id=app_id or "", + ) + self._exporter.increment_counter( + EnterpriseTelemetryCounter.REQUESTS, + 1, + self._labels( + **labels, + type="suggested_question", + model_provider=info.model_provider or "", + model_name=info.model_id or "", + ), + ) + + def _dataset_retrieval_trace(self, info: DatasetRetrievalTraceInfo) -> None: + metadata = self._metadata(info) + tenant_id, app_id, user_id = self._context_ids(info, metadata) + attrs = self._common_attrs(info) + attrs["dify.retrieval.error"] = info.error + attrs["dify.retrieval.status"] = "failed" if info.error else "succeeded" + if info.start_time and info.end_time: + attrs["dify.retrieval.duration"] = (info.end_time - info.start_time).total_seconds() + attrs["dify.workflow.run_id"] = metadata.get("workflow_run_id") + node_execution_id = metadata.get("node_execution_id") + if node_execution_id: + attrs["dify.node.execution_id"] = node_execution_id + + docs: list[dict[str, Any]] = [] + documents_any: Any = info.documents + documents_list: list[Any] = cast(list[Any], documents_any) if isinstance(documents_any, list) else [] + for entry in documents_list: + if isinstance(entry, dict): + entry_dict: dict[str, Any] = cast(dict[str, Any], entry) + docs.append(entry_dict) + dataset_ids: list[str] = [] + dataset_names: list[str] = 
[] + structured_docs: list[dict[str, Any]] = [] + for doc in docs: + meta_raw = doc.get("metadata") + meta: dict[str, Any] = cast(dict[str, Any], meta_raw) if isinstance(meta_raw, dict) else {} + did = meta.get("dataset_id") + dname = meta.get("dataset_name") + if did and did not in dataset_ids: + dataset_ids.append(did) + if dname and dname not in dataset_names: + dataset_names.append(dname) + structured_docs.append( + { + "dataset_id": did, + "document_id": meta.get("document_id"), + "segment_id": meta.get("segment_id"), + "score": meta.get("score"), + } + ) + + attrs["dify.dataset.id"] = self._maybe_json(dataset_ids) + attrs["dify.dataset.name"] = self._maybe_json(dataset_names) + attrs["dify.retrieval.document_count"] = len(docs) + + embedding_models_raw: Any = metadata.get("embedding_models") + embedding_models: dict[str, Any] = ( + cast(dict[str, Any], embedding_models_raw) if isinstance(embedding_models_raw, dict) else {} + ) + if embedding_models: + providers: list[str] = [] + models: list[str] = [] + for ds_info in embedding_models.values(): + if isinstance(ds_info, dict): + ds_info_dict: dict[str, Any] = cast(dict[str, Any], ds_info) + p = ds_info_dict.get("embedding_model_provider", "") + m = ds_info_dict.get("embedding_model", "") + if p and p not in providers: + providers.append(p) + if m and m not in models: + models.append(m) + attrs["dify.dataset.embedding_providers"] = self._maybe_json(providers) + attrs["dify.dataset.embedding_models"] = self._maybe_json(models) + + # Add rerank model to logs + rerank_provider = metadata.get("rerank_model_provider", "") + rerank_model = metadata.get("rerank_model_name", "") + if rerank_provider or rerank_model: + attrs["dify.retrieval.rerank_provider"] = rerank_provider + attrs["dify.retrieval.rerank_model"] = rerank_model + + ref = f"ref:message_id={info.message_id}" + retrieval_inputs = self._safe_payload_value(info.inputs) + attrs["dify.retrieval.query"] = self._content_or_ref(retrieval_inputs, ref) + 
attrs["dify.dataset.documents"] = self._content_or_ref(structured_docs, ref) + + emit_metric_only_event( + event_name=EnterpriseTelemetryEvent.DATASET_RETRIEVAL, + attributes=attrs, + trace_id_source=metadata.get("workflow_run_id") or (str(info.message_id) if info.message_id else None), + span_id_source=node_execution_id or (str(info.message_id) if info.message_id else None), + tenant_id=tenant_id, + user_id=user_id, + ) + + labels = self._labels( + tenant_id=tenant_id or "", + app_id=app_id or "", + ) + self._exporter.increment_counter( + EnterpriseTelemetryCounter.REQUESTS, + 1, + self._labels( + **labels, + type="dataset_retrieval", + ), + ) + + for did in dataset_ids: + # Get embedding model for this specific dataset + ds_embedding_info = embedding_models.get(did, {}) + embedding_provider = ds_embedding_info.get("embedding_model_provider", "") + embedding_model = ds_embedding_info.get("embedding_model", "") + + # Get rerank model (same for all datasets in this retrieval) + rerank_provider = metadata.get("rerank_model_provider", "") + rerank_model = metadata.get("rerank_model_name", "") + + self._exporter.increment_counter( + EnterpriseTelemetryCounter.DATASET_RETRIEVALS, + 1, + self._labels( + **labels, + dataset_id=did, + embedding_model_provider=embedding_provider, + embedding_model=embedding_model, + rerank_model_provider=rerank_provider, + rerank_model=rerank_model, + ), + ) + + def _generate_name_trace(self, info: GenerateNameTraceInfo) -> None: + metadata = self._metadata(info) + tenant_id, app_id, user_id = self._context_ids(info, metadata) + attrs = self._common_attrs(info) + attrs["dify.conversation.id"] = info.conversation_id + node_execution_id = metadata.get("node_execution_id") + if node_execution_id: + attrs["dify.node.execution_id"] = node_execution_id + + duration: float | None = None + if info.start_time is not None and info.end_time is not None: + duration = (info.end_time - info.start_time).total_seconds() + error: str | None = 
metadata.get("error") if metadata else None + status = "failed" if error else "succeeded" + attrs["dify.generate_name.duration"] = duration + attrs["dify.generate_name.status"] = status + attrs["dify.generate_name.error"] = error + + ref = f"ref:conversation_id={info.conversation_id}" + inputs = self._safe_payload_value(info.inputs) + outputs = self._safe_payload_value(info.outputs) + attrs["dify.generate_name.inputs"] = self._content_or_ref(inputs, ref) + attrs["dify.generate_name.outputs"] = self._content_or_ref(outputs, ref) + + emit_metric_only_event( + event_name=EnterpriseTelemetryEvent.GENERATE_NAME_EXECUTION, + attributes=attrs, + trace_id_source=info.resolved_trace_id, + span_id_source=node_execution_id, + tenant_id=tenant_id, + user_id=user_id, + ) + + labels = self._labels( + tenant_id=tenant_id or "", + app_id=app_id or "", + ) + self._exporter.increment_counter( + EnterpriseTelemetryCounter.REQUESTS, + 1, + self._labels( + **labels, + type="generate_name", + ), + ) + + def _prompt_generation_trace(self, info: PromptGenerationTraceInfo) -> None: + metadata = self._metadata(info) + tenant_id, app_id, user_id = self._context_ids(info, metadata) + attrs = { + "dify.trace_id": info.resolved_trace_id, + "dify.tenant_id": tenant_id, + "gen_ai.user.id": user_id, + "dify.app_id": app_id or "", + "dify.app.name": metadata.get("app_name"), + "dify.workspace.name": metadata.get("workspace_name"), + "dify.prompt_generation.operation_type": info.operation_type, + "gen_ai.provider.name": info.model_provider, + "gen_ai.request.model": info.model_name, + "gen_ai.usage.input_tokens": info.prompt_tokens, + "gen_ai.usage.output_tokens": info.completion_tokens, + "gen_ai.usage.total_tokens": info.total_tokens, + "dify.prompt_generation.duration": info.latency, + "dify.prompt_generation.status": "failed" if info.error else "succeeded", + "dify.prompt_generation.error": info.error, + } + node_execution_id = metadata.get("node_execution_id") + if node_execution_id: + 
attrs["dify.node.execution_id"] = node_execution_id + + if info.total_price is not None: + attrs["dify.prompt_generation.total_price"] = info.total_price + attrs["dify.prompt_generation.currency"] = info.currency + + ref = f"ref:trace_id={info.trace_id}" + outputs = self._safe_payload_value(info.outputs) + attrs["dify.prompt_generation.instruction"] = self._content_or_ref(info.instruction, ref) + attrs["dify.prompt_generation.output"] = self._content_or_ref(outputs, ref) + + emit_metric_only_event( + event_name=EnterpriseTelemetryEvent.PROMPT_GENERATION_EXECUTION, + attributes=attrs, + trace_id_source=info.resolved_trace_id, + span_id_source=node_execution_id, + tenant_id=tenant_id, + user_id=user_id, + ) + + token_labels = TokenMetricLabels( + tenant_id=tenant_id or "", + app_id=app_id or "", + operation_type=info.operation_type, + model_provider=info.model_provider, + model_name=info.model_name, + node_type="", + ).to_dict() + + labels = self._labels( + tenant_id=tenant_id or "", + app_id=app_id or "", + operation_type=info.operation_type, + model_provider=info.model_provider, + model_name=info.model_name, + ) + + self._exporter.increment_counter(EnterpriseTelemetryCounter.TOKENS, info.total_tokens, token_labels) + if info.prompt_tokens > 0: + self._exporter.increment_counter(EnterpriseTelemetryCounter.INPUT_TOKENS, info.prompt_tokens, token_labels) + if info.completion_tokens > 0: + self._exporter.increment_counter( + EnterpriseTelemetryCounter.OUTPUT_TOKENS, info.completion_tokens, token_labels + ) + + prompt_status = "failed" if info.error else "succeeded" + self._exporter.increment_counter( + EnterpriseTelemetryCounter.REQUESTS, + 1, + self._labels( + **labels, + type="prompt_generation", + status=prompt_status, + ), + ) + + self._exporter.record_histogram( + EnterpriseTelemetryHistogram.PROMPT_GENERATION_DURATION, + info.latency, + labels, + ) + + if info.error: + self._exporter.increment_counter( + EnterpriseTelemetryCounter.ERRORS, + 1, + self._labels( + 
**labels, + type="prompt_generation", + ), + ) diff --git a/api/enterprise/telemetry/entities/__init__.py b/api/enterprise/telemetry/entities/__init__.py new file mode 100644 index 0000000000..4a9bd3dbf8 --- /dev/null +++ b/api/enterprise/telemetry/entities/__init__.py @@ -0,0 +1,121 @@ +from enum import StrEnum +from typing import cast + +from opentelemetry.util.types import AttributeValue +from pydantic import BaseModel, ConfigDict + + +class EnterpriseTelemetrySpan(StrEnum): + WORKFLOW_RUN = "dify.workflow.run" + NODE_EXECUTION = "dify.node.execution" + DRAFT_NODE_EXECUTION = "dify.node.execution.draft" + + +class EnterpriseTelemetryEvent(StrEnum): + """Event names for enterprise telemetry logs.""" + + APP_CREATED = "dify.app.created" + APP_UPDATED = "dify.app.updated" + APP_DELETED = "dify.app.deleted" + FEEDBACK_CREATED = "dify.feedback.created" + WORKFLOW_RUN = "dify.workflow.run" + MESSAGE_RUN = "dify.message.run" + TOOL_EXECUTION = "dify.tool.execution" + MODERATION_CHECK = "dify.moderation.check" + SUGGESTED_QUESTION_GENERATION = "dify.suggested_question.generation" + DATASET_RETRIEVAL = "dify.dataset.retrieval" + GENERATE_NAME_EXECUTION = "dify.generate_name.execution" + PROMPT_GENERATION_EXECUTION = "dify.prompt_generation.execution" + REHYDRATION_FAILED = "dify.telemetry.rehydration_failed" + + +class EnterpriseTelemetryCounter(StrEnum): + TOKENS = "tokens" + INPUT_TOKENS = "input_tokens" + OUTPUT_TOKENS = "output_tokens" + REQUESTS = "requests" + ERRORS = "errors" + FEEDBACK = "feedback" + DATASET_RETRIEVALS = "dataset_retrievals" + APP_CREATED = "app_created" + APP_UPDATED = "app_updated" + APP_DELETED = "app_deleted" + + +class EnterpriseTelemetryHistogram(StrEnum): + WORKFLOW_DURATION = "workflow_duration" + NODE_DURATION = "node_duration" + MESSAGE_DURATION = "message_duration" + MESSAGE_TTFT = "message_ttft" + TOOL_DURATION = "tool_duration" + PROMPT_GENERATION_DURATION = "prompt_generation_duration" + + +class TokenMetricLabels(BaseModel): + 
"""Unified label structure for all dify.tokens.* metrics. + + All token counters (dify.tokens.input, dify.tokens.output, dify.tokens.total) MUST + use this exact label set to ensure consistent filtering and aggregation across + different operation types. + + Attributes: + tenant_id: Tenant identifier. + app_id: Application identifier. + operation_type: Source of token usage (workflow | node_execution | message | + rule_generate | code_generate | structured_output | instruction_modify). + model_provider: LLM provider name. Empty string if not applicable (e.g., workflow-level). + model_name: LLM model name. Empty string if not applicable (e.g., workflow-level). + node_type: Workflow node type. Empty string unless operation_type=node_execution. + + Usage: + labels = TokenMetricLabels( + tenant_id="tenant-123", + app_id="app-456", + operation_type=OperationType.WORKFLOW, + model_provider="", + model_name="", + node_type="", + ) + exporter.increment_counter( + EnterpriseTelemetryCounter.INPUT_TOKENS, + 100, + labels.to_dict() + ) + + Design rationale: + Without this unified structure, tokens get double-counted when querying totals + because workflow.total_tokens is already the sum of all node tokens. The + operation_type label allows filtering to separate workflow-level aggregates from + node-level detail, while keeping the same label cardinality for consistent queries.
+ """ + + tenant_id: str + app_id: str + operation_type: str + model_provider: str + model_name: str + node_type: str + + model_config = ConfigDict(extra="forbid", frozen=True) + + def to_dict(self) -> dict[str, AttributeValue]: + return cast( + dict[str, AttributeValue], + { + "tenant_id": self.tenant_id, + "app_id": self.app_id, + "operation_type": self.operation_type, + "model_provider": self.model_provider, + "model_name": self.model_name, + "node_type": self.node_type, + }, + ) + + +__all__ = [ + "EnterpriseTelemetryCounter", + "EnterpriseTelemetryEvent", + "EnterpriseTelemetryHistogram", + "EnterpriseTelemetrySpan", + "TokenMetricLabels", +] diff --git a/api/enterprise/telemetry/event_handlers.py b/api/enterprise/telemetry/event_handlers.py new file mode 100644 index 0000000000..d8b4208c69 --- /dev/null +++ b/api/enterprise/telemetry/event_handlers.py @@ -0,0 +1,72 @@ +"""Blinker signal handlers for enterprise telemetry. + +Registered at import time via ``@signal.connect`` decorators. +Import must happen during ``ext_enterprise_telemetry.init_app()`` to +ensure handlers fire. Each handler delegates to ``core.telemetry.gateway`` +which handles routing, EE-gating, and dispatch. + +All handlers are best-effort: exceptions are caught and logged so that +telemetry failures never break user-facing operations. 
+""" + +from __future__ import annotations + +import logging + +from events.app_event import app_was_created, app_was_deleted, app_was_updated + +logger = logging.getLogger(__name__) + +__all__ = [ + "_handle_app_created", + "_handle_app_deleted", + "_handle_app_updated", +] + + +@app_was_created.connect +def _handle_app_created(sender: object, **kwargs: object) -> None: + try: + from core.telemetry.gateway import emit as gateway_emit + from enterprise.telemetry.contracts import TelemetryCase + + gateway_emit( + case=TelemetryCase.APP_CREATED, + context={"tenant_id": str(getattr(sender, "tenant_id", "") or "")}, + payload={ + "app_id": getattr(sender, "id", None), + "mode": getattr(sender, "mode", None), + }, + ) + except Exception: + logger.warning("Failed to emit app_created telemetry", exc_info=True) + + +@app_was_updated.connect +def _handle_app_updated(sender: object, **kwargs: object) -> None: + try: + from core.telemetry.gateway import emit as gateway_emit + from enterprise.telemetry.contracts import TelemetryCase + + gateway_emit( + case=TelemetryCase.APP_UPDATED, + context={"tenant_id": str(getattr(sender, "tenant_id", "") or "")}, + payload={"app_id": getattr(sender, "id", None)}, + ) + except Exception: + logger.warning("Failed to emit app_updated telemetry", exc_info=True) + + +@app_was_deleted.connect +def _handle_app_deleted(sender: object, **kwargs: object) -> None: + try: + from core.telemetry.gateway import emit as gateway_emit + from enterprise.telemetry.contracts import TelemetryCase + + gateway_emit( + case=TelemetryCase.APP_DELETED, + context={"tenant_id": str(getattr(sender, "tenant_id", "") or "")}, + payload={"app_id": getattr(sender, "id", None)}, + ) + except Exception: + logger.warning("Failed to emit app_deleted telemetry", exc_info=True) diff --git a/api/enterprise/telemetry/exporter.py b/api/enterprise/telemetry/exporter.py new file mode 100644 index 0000000000..b2f860764f --- /dev/null +++ b/api/enterprise/telemetry/exporter.py @@ 
-0,0 +1,283 @@ +"""Enterprise OTEL exporter — shared by EnterpriseOtelTrace, event handlers, and direct instrumentation. + +Uses dedicated TracerProvider and MeterProvider instances (configurable sampling, +independent from ext_otel.py infrastructure). + +Initialized once during Flask extension init (single-threaded via ext_enterprise_telemetry.py). +Accessed via ``ext_enterprise_telemetry.get_enterprise_exporter()`` from any thread/process. +""" + +import logging +import socket +import uuid +from datetime import UTC, datetime +from typing import Any, cast + +from opentelemetry import trace +from opentelemetry.baggage import get_all +from opentelemetry.baggage.propagation import W3CBaggagePropagator +from opentelemetry.context import Context +from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import OTLPMetricExporter as GRPCMetricExporter +from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter as GRPCSpanExporter +from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter as HTTPMetricExporter +from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter as HTTPSpanExporter +from opentelemetry.sdk.metrics import MeterProvider +from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader +from opentelemetry.sdk.resources import Resource +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import BatchSpanProcessor +from opentelemetry.sdk.trace.sampling import ParentBasedTraceIdRatio +from opentelemetry.semconv.resource import ResourceAttributes +from opentelemetry.trace import SpanContext, TraceFlags +from opentelemetry.util.types import Attributes, AttributeValue + +from configs import dify_config +from enterprise.telemetry.entities import EnterpriseTelemetryCounter, EnterpriseTelemetryHistogram +from enterprise.telemetry.id_generator import ( + CorrelationIdGenerator, + compute_deterministic_span_id, + set_correlation_id, + 
set_span_id_source, +) + +logger = logging.getLogger(__name__) + + +def is_enterprise_telemetry_enabled() -> bool: + return bool(dify_config.ENTERPRISE_ENABLED and dify_config.ENTERPRISE_TELEMETRY_ENABLED) + + +def _parse_otlp_headers(raw: str) -> dict[str, str]: + ctx = W3CBaggagePropagator().extract({"baggage": raw}) + return {k: v for k, v in get_all(ctx).items() if isinstance(v, str)} + + +def _datetime_to_ns(dt: datetime) -> int: + """Convert a datetime to nanoseconds since epoch (OTEL convention).""" + # Ensure we always interpret naive datetimes as UTC instead of local time. + if dt.tzinfo is None: + dt = dt.replace(tzinfo=UTC) + else: + dt = dt.astimezone(UTC) + return int(dt.timestamp() * 1_000_000_000) + + +class _ExporterFactory: + def __init__(self, protocol: str, endpoint: str, headers: dict[str, str], insecure: bool): + self._protocol = protocol + self._endpoint = endpoint + self._headers = headers + self._grpc_headers = tuple(headers.items()) if headers else None + self._http_headers = headers or None + self._insecure = insecure + + def create_trace_exporter(self) -> HTTPSpanExporter | GRPCSpanExporter: + if self._protocol == "grpc": + return GRPCSpanExporter( + endpoint=self._endpoint or None, + headers=self._grpc_headers, + insecure=self._insecure, + ) + trace_endpoint = f"{self._endpoint}/v1/traces" if self._endpoint else "" + return HTTPSpanExporter(endpoint=trace_endpoint or None, headers=self._http_headers) + + def create_metric_exporter(self) -> HTTPMetricExporter | GRPCMetricExporter: + if self._protocol == "grpc": + return GRPCMetricExporter( + endpoint=self._endpoint or None, + headers=self._grpc_headers, + insecure=self._insecure, + ) + metric_endpoint = f"{self._endpoint}/v1/metrics" if self._endpoint else "" + return HTTPMetricExporter(endpoint=metric_endpoint or None, headers=self._http_headers) + + +class EnterpriseExporter: + """Shared OTEL exporter for all enterprise telemetry. 
+ + ``export_span`` creates spans with optional real timestamps, deterministic + span/trace IDs, and cross-workflow parent linking. + ``increment_counter`` / ``record_histogram`` emit OTEL metrics at 100% accuracy. + """ + + def __init__(self, config: object) -> None: + endpoint: str = getattr(config, "ENTERPRISE_OTLP_ENDPOINT", "") + headers_raw: str = getattr(config, "ENTERPRISE_OTLP_HEADERS", "") + protocol: str = (getattr(config, "ENTERPRISE_OTLP_PROTOCOL", "http") or "http").lower() + service_name: str = getattr(config, "ENTERPRISE_SERVICE_NAME", "dify") + sampling_rate: float = getattr(config, "ENTERPRISE_OTEL_SAMPLING_RATE", 1.0) + self.include_content: bool = getattr(config, "ENTERPRISE_INCLUDE_CONTENT", True) + api_key: str = getattr(config, "ENTERPRISE_OTLP_API_KEY", "") + + # Auto-detect TLS: https:// uses secure, everything else is insecure + insecure = not endpoint.startswith("https://") + + resource = Resource( + attributes={ + ResourceAttributes.SERVICE_NAME: service_name, + ResourceAttributes.HOST_NAME: socket.gethostname(), + } + ) + sampler = ParentBasedTraceIdRatio(sampling_rate) + id_generator = CorrelationIdGenerator() + self._tracer_provider = TracerProvider(resource=resource, sampler=sampler, id_generator=id_generator) + + headers = _parse_otlp_headers(headers_raw) + if api_key: + if "authorization" in headers: + logger.warning( + "ENTERPRISE_OTLP_API_KEY is set but ENTERPRISE_OTLP_HEADERS also contains " + "'authorization'; the API key will take precedence." 
+ ) + headers["authorization"] = f"Bearer {api_key}" + factory = _ExporterFactory(protocol, endpoint, headers, insecure=insecure) + + trace_exporter = factory.create_trace_exporter() + self._tracer_provider.add_span_processor(BatchSpanProcessor(trace_exporter)) + self._tracer = self._tracer_provider.get_tracer("dify.enterprise") + + metric_exporter = factory.create_metric_exporter() + self._meter_provider = MeterProvider( + resource=resource, + metric_readers=[PeriodicExportingMetricReader(metric_exporter)], + ) + meter = self._meter_provider.get_meter("dify.enterprise") + self._counters = { + EnterpriseTelemetryCounter.TOKENS: meter.create_counter("dify.tokens.total", unit="{token}"), + EnterpriseTelemetryCounter.INPUT_TOKENS: meter.create_counter("dify.tokens.input", unit="{token}"), + EnterpriseTelemetryCounter.OUTPUT_TOKENS: meter.create_counter("dify.tokens.output", unit="{token}"), + EnterpriseTelemetryCounter.REQUESTS: meter.create_counter("dify.requests.total", unit="{request}"), + EnterpriseTelemetryCounter.ERRORS: meter.create_counter("dify.errors.total", unit="{error}"), + EnterpriseTelemetryCounter.FEEDBACK: meter.create_counter("dify.feedback.total", unit="{feedback}"), + EnterpriseTelemetryCounter.DATASET_RETRIEVALS: meter.create_counter( + "dify.dataset.retrievals.total", unit="{retrieval}" + ), + EnterpriseTelemetryCounter.APP_CREATED: meter.create_counter("dify.app.created.total", unit="{app}"), + EnterpriseTelemetryCounter.APP_UPDATED: meter.create_counter("dify.app.updated.total", unit="{app}"), + EnterpriseTelemetryCounter.APP_DELETED: meter.create_counter("dify.app.deleted.total", unit="{app}"), + } + self._histograms = { + EnterpriseTelemetryHistogram.WORKFLOW_DURATION: meter.create_histogram("dify.workflow.duration", unit="s"), + EnterpriseTelemetryHistogram.NODE_DURATION: meter.create_histogram("dify.node.duration", unit="s"), + EnterpriseTelemetryHistogram.MESSAGE_DURATION: meter.create_histogram("dify.message.duration", unit="s"), + 
EnterpriseTelemetryHistogram.MESSAGE_TTFT: meter.create_histogram( + "dify.message.time_to_first_token", unit="s" + ), + EnterpriseTelemetryHistogram.TOOL_DURATION: meter.create_histogram("dify.tool.duration", unit="s"), + EnterpriseTelemetryHistogram.PROMPT_GENERATION_DURATION: meter.create_histogram( + "dify.prompt_generation.duration", unit="s" + ), + } + + def export_span( + self, + name: str, + attributes: dict[str, Any], + correlation_id: str | None = None, + span_id_source: str | None = None, + start_time: datetime | None = None, + end_time: datetime | None = None, + trace_correlation_override: str | None = None, + parent_span_id_source: str | None = None, + ) -> None: + """Export an OTEL span with optional deterministic IDs and real timestamps. + + Args: + name: Span operation name. + attributes: Span attributes dict. + correlation_id: Source for trace_id derivation (groups spans in one trace). + span_id_source: Source for deterministic span_id (e.g. workflow_run_id or node_execution_id). + start_time: Real span start time. When None, uses current time. + end_time: Real span end time. When None, span ends immediately. + trace_correlation_override: Override trace_id source (for cross-workflow linking). + When set, trace_id is derived from this instead of ``correlation_id``. + parent_span_id_source: Override parent span_id source (for cross-workflow linking). + When set, parent span_id is derived from this value. When None and + ``correlation_id`` is set, parent is the workflow root span. + """ + effective_trace_correlation = trace_correlation_override or correlation_id + set_correlation_id(effective_trace_correlation) + set_span_id_source(span_id_source) + + try: + parent_context: Context | None = None + # A span is the "root" of its correlation group when span_id_source == correlation_id + # (i.e. a workflow root span). All other spans are children. + if parent_span_id_source: + # Cross-workflow linking: parent is an explicit span (e.g. 
tool node in outer workflow) + parent_span_id = compute_deterministic_span_id(parent_span_id_source) + try: + parent_trace_id = int(uuid.UUID(effective_trace_correlation)) if effective_trace_correlation else 0 + except (ValueError, AttributeError): + logger.warning( + "Invalid trace correlation UUID for cross-workflow link: %s, span=%s", + effective_trace_correlation, + name, + ) + parent_trace_id = 0 + if parent_trace_id: + parent_span_context = SpanContext( + trace_id=parent_trace_id, + span_id=parent_span_id, + is_remote=True, + trace_flags=TraceFlags(TraceFlags.SAMPLED), + ) + parent_context = trace.set_span_in_context(trace.NonRecordingSpan(parent_span_context)) + elif correlation_id and correlation_id != span_id_source: + # Child span: parent is the correlation-group root (workflow root span) + parent_span_id = compute_deterministic_span_id(correlation_id) + try: + parent_trace_id = int(uuid.UUID(effective_trace_correlation or correlation_id)) + except (ValueError, AttributeError): + logger.warning( + "Invalid trace correlation UUID for child span link: %s, span=%s", + effective_trace_correlation or correlation_id, + name, + ) + parent_trace_id = 0 + if parent_trace_id: + parent_span_context = SpanContext( + trace_id=parent_trace_id, + span_id=parent_span_id, + is_remote=True, + trace_flags=TraceFlags(TraceFlags.SAMPLED), + ) + parent_context = trace.set_span_in_context(trace.NonRecordingSpan(parent_span_context)) + + span_start_time = _datetime_to_ns(start_time) if start_time is not None else None + span_end_on_exit = end_time is None + + with self._tracer.start_as_current_span( + name, + context=parent_context, + start_time=span_start_time, + end_on_exit=span_end_on_exit, + ) as span: + for key, value in attributes.items(): + if value is not None: + span.set_attribute(key, value) + if end_time is not None: + span.end(end_time=_datetime_to_ns(end_time)) + except Exception: + logger.exception("Failed to export span %s", name) + finally: + 
set_correlation_id(None) + set_span_id_source(None) + + def increment_counter( + self, name: EnterpriseTelemetryCounter, value: int, labels: dict[str, AttributeValue] + ) -> None: + counter = self._counters.get(name) + if counter: + counter.add(value, cast(Attributes, labels)) + + def record_histogram( + self, name: EnterpriseTelemetryHistogram, value: float, labels: dict[str, AttributeValue] + ) -> None: + histogram = self._histograms.get(name) + if histogram: + histogram.record(value, cast(Attributes, labels)) + + def shutdown(self) -> None: + self._tracer_provider.shutdown() + self._meter_provider.shutdown() diff --git a/api/enterprise/telemetry/id_generator.py b/api/enterprise/telemetry/id_generator.py new file mode 100644 index 0000000000..f3e5d6d0d6 --- /dev/null +++ b/api/enterprise/telemetry/id_generator.py @@ -0,0 +1,75 @@ +"""Custom OTEL ID Generator for correlation-based trace/span ID derivation. + +Uses contextvars for thread-safe correlation_id -> trace_id mapping. +When a span_id_source is set, the span_id is derived deterministically +from that value, enabling any span to reference another as parent +without depending on span creation order. +""" + +import random +import uuid +from contextvars import ContextVar + +from opentelemetry.sdk.trace.id_generator import IdGenerator + +_correlation_id_context: ContextVar[str | None] = ContextVar("correlation_id", default=None) +_span_id_source_context: ContextVar[str | None] = ContextVar("span_id_source", default=None) + + +def set_correlation_id(correlation_id: str | None) -> None: + _correlation_id_context.set(correlation_id) + + +def get_correlation_id() -> str | None: + return _correlation_id_context.get() + + +def set_span_id_source(source_id: str | None) -> None: + """Set the source for deterministic span_id generation. + + When set, ``generate_span_id()`` derives the span_id from this value + (lower 64 bits of the UUID). 
Pass the ``workflow_run_id`` for workflow + root spans or ``node_execution_id`` for node spans. + """ + _span_id_source_context.set(source_id) + + +def compute_deterministic_span_id(source_id: str) -> int: + """Derive a deterministic span_id from any UUID string. + + Uses the lower 64 bits of the UUID, guaranteeing non-zero output + (OTEL requires span_id != 0). + """ + span_id = uuid.UUID(source_id).int & ((1 << 64) - 1) + return span_id if span_id != 0 else 1 + + +class CorrelationIdGenerator(IdGenerator): + """ID generator that derives trace_id and optionally span_id from context. + + - trace_id: always derived from correlation_id (groups all spans in one trace) + - span_id: derived from span_id_source when set (enables deterministic + parent-child linking), otherwise random + """ + + def generate_trace_id(self) -> int: + correlation_id = _correlation_id_context.get() + if correlation_id: + try: + return uuid.UUID(correlation_id).int + except (ValueError, AttributeError): + pass + return random.getrandbits(128) + + def generate_span_id(self) -> int: + source = _span_id_source_context.get() + if source: + try: + return compute_deterministic_span_id(source) + except (ValueError, AttributeError): + pass + + span_id = random.getrandbits(64) + while span_id == 0: + span_id = random.getrandbits(64) + return span_id diff --git a/api/enterprise/telemetry/metric_handler.py b/api/enterprise/telemetry/metric_handler.py new file mode 100644 index 0000000000..ffd9a7e2b5 --- /dev/null +++ b/api/enterprise/telemetry/metric_handler.py @@ -0,0 +1,421 @@ +"""Enterprise metric/log event handler. + +This module processes metric and log telemetry events after they've been +dequeued from the enterprise_telemetry Celery queue. It handles case routing, +idempotency checking, and payload rehydration. 
+""" + +from __future__ import annotations + +import json +import logging +from datetime import UTC, datetime +from typing import Any + +from enterprise.telemetry.contracts import TelemetryCase, TelemetryEnvelope +from extensions.ext_redis import redis_client +from extensions.ext_storage import storage + +logger = logging.getLogger(__name__) + + +class EnterpriseMetricHandler: + """Handler for enterprise metric and log telemetry events. + + Processes envelopes from the enterprise_telemetry queue, routing each + case to the appropriate handler method. Implements idempotency checking + and payload rehydration with fallback. + """ + + def _increment_diagnostic_counter(self, counter_name: str, labels: dict[str, str] | None = None) -> None: + """Increment a diagnostic counter for operational monitoring. + + Args: + counter_name: Name of the counter (e.g., 'processed_total', 'deduped_total'). + labels: Optional labels for the counter. + """ + try: + from extensions.ext_enterprise_telemetry import get_enterprise_exporter + + exporter = get_enterprise_exporter() + if not exporter: + return + + full_counter_name = f"enterprise_telemetry.handler.{counter_name}" + logger.debug( + "Diagnostic counter: %s, labels=%s", + full_counter_name, + labels or {}, + ) + except Exception: + logger.debug("Failed to increment diagnostic counter: %s", counter_name, exc_info=True) + + def handle(self, envelope: TelemetryEnvelope) -> None: + """Main entry point for processing telemetry envelopes. + + Args: + envelope: The telemetry envelope to process. 
+ """ + # Check for duplicate events + if self._is_duplicate(envelope): + logger.debug( + "Skipping duplicate event: tenant_id=%s, event_id=%s", + envelope.tenant_id, + envelope.event_id, + ) + self._increment_diagnostic_counter("deduped_total") + return + + # Route to appropriate handler based on case + case = envelope.case + if case == TelemetryCase.APP_CREATED: + self._on_app_created(envelope) + self._increment_diagnostic_counter("processed_total", {"case": "app_created"}) + elif case == TelemetryCase.APP_UPDATED: + self._on_app_updated(envelope) + self._increment_diagnostic_counter("processed_total", {"case": "app_updated"}) + elif case == TelemetryCase.APP_DELETED: + self._on_app_deleted(envelope) + self._increment_diagnostic_counter("processed_total", {"case": "app_deleted"}) + elif case == TelemetryCase.FEEDBACK_CREATED: + self._on_feedback_created(envelope) + self._increment_diagnostic_counter("processed_total", {"case": "feedback_created"}) + elif case == TelemetryCase.MESSAGE_RUN: + self._on_message_run(envelope) + self._increment_diagnostic_counter("processed_total", {"case": "message_run"}) + elif case == TelemetryCase.TOOL_EXECUTION: + self._on_tool_execution(envelope) + self._increment_diagnostic_counter("processed_total", {"case": "tool_execution"}) + elif case == TelemetryCase.MODERATION_CHECK: + self._on_moderation_check(envelope) + self._increment_diagnostic_counter("processed_total", {"case": "moderation_check"}) + elif case == TelemetryCase.SUGGESTED_QUESTION: + self._on_suggested_question(envelope) + self._increment_diagnostic_counter("processed_total", {"case": "suggested_question"}) + elif case == TelemetryCase.DATASET_RETRIEVAL: + self._on_dataset_retrieval(envelope) + self._increment_diagnostic_counter("processed_total", {"case": "dataset_retrieval"}) + elif case == TelemetryCase.GENERATE_NAME: + self._on_generate_name(envelope) + self._increment_diagnostic_counter("processed_total", {"case": "generate_name"}) + elif case == 
TelemetryCase.PROMPT_GENERATION: + self._on_prompt_generation(envelope) + self._increment_diagnostic_counter("processed_total", {"case": "prompt_generation"}) + else: + logger.warning( + "Unknown telemetry case: %s (tenant_id=%s, event_id=%s)", + case, + envelope.tenant_id, + envelope.event_id, + ) + + def _is_duplicate(self, envelope: TelemetryEnvelope) -> bool: + """Check if this event has already been processed. + + Uses Redis with TTL for deduplication. Returns True if duplicate, + False if first time seeing this event. + + Args: + envelope: The telemetry envelope to check. + + Returns: + True if this event_id has been seen before, False otherwise. + """ + dedup_key = f"telemetry:dedup:{envelope.tenant_id}:{envelope.event_id}" + + try: + # Atomic set-if-not-exists with 1h TTL + # Returns True if key was set (first time), None if already exists (duplicate) + was_set = redis_client.set(dedup_key, b"1", nx=True, ex=3600) + return was_set is None + except Exception: + # Fail open: if Redis is unavailable, process the event + # (prefer occasional duplicate over lost data) + logger.warning( + "Redis unavailable for deduplication check, processing event anyway: %s", + envelope.event_id, + exc_info=True, + ) + return False + + def _rehydrate(self, envelope: TelemetryEnvelope) -> dict[str, Any]: + """Rehydrate payload from storage reference or inline data. + + If the envelope payload is empty and metadata contains a + ``payload_ref``, the full payload is loaded from object storage + (where the gateway wrote it as JSON). When both the inline + payload and storage resolution fail, a degraded-event marker + is emitted so the gap is observable. + + Args: + envelope: The telemetry envelope containing payload data. + + Returns: + The rehydrated payload dictionary, or ``{}`` on total failure. + """ + payload = envelope.payload + + # Resolve from object storage when the gateway offloaded a large payload. 
+ if not payload and envelope.metadata: + payload_ref = envelope.metadata.get("payload_ref") + if payload_ref: + try: + payload_bytes = storage.load(payload_ref) + payload = json.loads(payload_bytes.decode("utf-8")) + logger.debug("Loaded payload from storage: key=%s", payload_ref) + except Exception: + logger.warning( + "Failed to load payload from storage: key=%s, event_id=%s", + payload_ref, + envelope.event_id, + exc_info=True, + ) + + if not payload: + # Storage resolution failed or no data available — emit degraded event. + logger.error( + "Payload rehydration failed for event_id=%s, tenant_id=%s, case=%s", + envelope.event_id, + envelope.tenant_id, + envelope.case, + ) + from enterprise.telemetry.entities import EnterpriseTelemetryEvent + from enterprise.telemetry.telemetry_log import emit_metric_only_event + + emit_metric_only_event( + event_name=EnterpriseTelemetryEvent.REHYDRATION_FAILED, + attributes={ + "tenant_id": envelope.tenant_id, + "dify.telemetry.error": f"Payload rehydration failed for event_id={envelope.event_id}", + "dify.telemetry.payload_type": envelope.case, + "dify.telemetry.correlation_id": envelope.event_id, + }, + tenant_id=envelope.tenant_id, + ) + self._increment_diagnostic_counter("rehydration_failed_total") + return {} + + return payload + + # Stub methods for each metric/log case + # These will be implemented in later tasks with actual emission logic + + def _on_app_created(self, envelope: TelemetryEnvelope) -> None: + """Handle app created event.""" + from enterprise.telemetry.entities import EnterpriseTelemetryCounter, EnterpriseTelemetryEvent + from enterprise.telemetry.telemetry_log import emit_metric_only_event + from extensions.ext_enterprise_telemetry import get_enterprise_exporter + + exporter = get_enterprise_exporter() + if not exporter: + logger.debug("No exporter available for APP_CREATED: event_id=%s", envelope.event_id) + return + + payload = self._rehydrate(envelope) + if not payload: + return + + attrs = { + 
"dify.app_id": payload.get("app_id"), + "dify.tenant_id": envelope.tenant_id, + "dify.event.id": envelope.event_id, + "dify.app.mode": payload.get("mode"), + "dify.app.created_at": datetime.now(UTC).isoformat(), + } + + emit_metric_only_event( + event_name=EnterpriseTelemetryEvent.APP_CREATED, + attributes=attrs, + tenant_id=envelope.tenant_id, + ) + exporter.increment_counter( + EnterpriseTelemetryCounter.APP_CREATED, + 1, + { + "tenant_id": envelope.tenant_id, + "app_id": str(payload.get("app_id", "")), + "mode": str(payload.get("mode", "")), + }, + ) + + def _on_app_updated(self, envelope: TelemetryEnvelope) -> None: + """Handle app updated event.""" + from enterprise.telemetry.entities import EnterpriseTelemetryCounter, EnterpriseTelemetryEvent + from enterprise.telemetry.telemetry_log import emit_metric_only_event + from extensions.ext_enterprise_telemetry import get_enterprise_exporter + + exporter = get_enterprise_exporter() + if not exporter: + logger.debug("No exporter available for APP_UPDATED: event_id=%s", envelope.event_id) + return + + payload = self._rehydrate(envelope) + if not payload: + return + + attrs = { + "dify.app_id": payload.get("app_id"), + "dify.tenant_id": envelope.tenant_id, + "dify.event.id": envelope.event_id, + "dify.app.updated_at": datetime.now(UTC).isoformat(), + } + + emit_metric_only_event( + event_name=EnterpriseTelemetryEvent.APP_UPDATED, + attributes=attrs, + tenant_id=envelope.tenant_id, + ) + exporter.increment_counter( + EnterpriseTelemetryCounter.APP_UPDATED, + 1, + { + "tenant_id": envelope.tenant_id, + "app_id": str(payload.get("app_id", "")), + }, + ) + + def _on_app_deleted(self, envelope: TelemetryEnvelope) -> None: + """Handle app deleted event.""" + from enterprise.telemetry.entities import EnterpriseTelemetryCounter, EnterpriseTelemetryEvent + from enterprise.telemetry.telemetry_log import emit_metric_only_event + from extensions.ext_enterprise_telemetry import get_enterprise_exporter + + exporter = 
get_enterprise_exporter() + if not exporter: + logger.debug("No exporter available for APP_DELETED: event_id=%s", envelope.event_id) + return + + payload = self._rehydrate(envelope) + if not payload: + return + + attrs = { + "dify.app_id": payload.get("app_id"), + "dify.tenant_id": envelope.tenant_id, + "dify.event.id": envelope.event_id, + "dify.app.deleted_at": datetime.now(UTC).isoformat(), + } + + emit_metric_only_event( + event_name=EnterpriseTelemetryEvent.APP_DELETED, + attributes=attrs, + tenant_id=envelope.tenant_id, + ) + exporter.increment_counter( + EnterpriseTelemetryCounter.APP_DELETED, + 1, + { + "tenant_id": envelope.tenant_id, + "app_id": str(payload.get("app_id", "")), + }, + ) + + def _on_feedback_created(self, envelope: TelemetryEnvelope) -> None: + """Handle feedback created event.""" + from enterprise.telemetry.entities import EnterpriseTelemetryCounter, EnterpriseTelemetryEvent + from enterprise.telemetry.telemetry_log import emit_metric_only_event + from extensions.ext_enterprise_telemetry import get_enterprise_exporter + + exporter = get_enterprise_exporter() + if not exporter: + logger.debug("No exporter available for FEEDBACK_CREATED: event_id=%s", envelope.event_id) + return + + payload = self._rehydrate(envelope) + if not payload: + return + + include_content = exporter.include_content + attrs: dict = { + "dify.message.id": payload.get("message_id"), + "dify.tenant_id": envelope.tenant_id, + "dify.event.id": envelope.event_id, + "dify.app_id": payload.get("app_id"), + "dify.conversation.id": payload.get("conversation_id"), + "gen_ai.user.id": payload.get("from_end_user_id") or payload.get("from_account_id"), + "dify.feedback.rating": payload.get("rating"), + "dify.feedback.from_source": payload.get("from_source"), + "dify.feedback.created_at": datetime.now(UTC).isoformat(), + } + if include_content: + attrs["dify.feedback.content"] = payload.get("content") + + user_id = payload.get("from_end_user_id") or payload.get("from_account_id") + 
emit_metric_only_event( + event_name=EnterpriseTelemetryEvent.FEEDBACK_CREATED, + attributes=attrs, + tenant_id=envelope.tenant_id, + user_id=str(user_id or ""), + ) + exporter.increment_counter( + EnterpriseTelemetryCounter.FEEDBACK, + 1, + { + "tenant_id": envelope.tenant_id, + "app_id": str(payload.get("app_id", "")), + "rating": str(payload.get("rating", "")), + }, + ) + + def _on_message_run(self, envelope: TelemetryEnvelope) -> None: + """Handle message run event. + + Intentionally a no-op: metrics and structured logs for message runs are + emitted directly by EnterpriseOtelTrace._message_trace at trace time, + not through the metric handler queue path. + """ + logger.debug("Processing MESSAGE_RUN: event_id=%s", envelope.event_id) + + def _on_tool_execution(self, envelope: TelemetryEnvelope) -> None: + """Handle tool execution event. + + Intentionally a no-op: metrics and structured logs for tool executions + are emitted directly by EnterpriseOtelTrace._tool_trace at trace time, + not through the metric handler queue path. + """ + logger.debug("Processing TOOL_EXECUTION: event_id=%s", envelope.event_id) + + def _on_moderation_check(self, envelope: TelemetryEnvelope) -> None: + """Handle moderation check event. + + Intentionally a no-op: metrics and structured logs for moderation checks + are emitted directly by EnterpriseOtelTrace._moderation_trace at trace time, + not through the metric handler queue path. + """ + logger.debug("Processing MODERATION_CHECK: event_id=%s", envelope.event_id) + + def _on_suggested_question(self, envelope: TelemetryEnvelope) -> None: + """Handle suggested question event. + + Intentionally a no-op: metrics and structured logs for suggested questions + are emitted directly by EnterpriseOtelTrace._suggested_question_trace at + trace time, not through the metric handler queue path. 
+ """ + logger.debug("Processing SUGGESTED_QUESTION: event_id=%s", envelope.event_id) + + def _on_dataset_retrieval(self, envelope: TelemetryEnvelope) -> None: + """Handle dataset retrieval event. + + Intentionally a no-op: metrics and structured logs for dataset retrievals + are emitted directly by EnterpriseOtelTrace._dataset_retrieval_trace at + trace time, not through the metric handler queue path. + """ + logger.debug("Processing DATASET_RETRIEVAL: event_id=%s", envelope.event_id) + + def _on_generate_name(self, envelope: TelemetryEnvelope) -> None: + """Handle generate name event. + + Intentionally a no-op: metrics and structured logs for generate name + operations are emitted directly by EnterpriseOtelTrace._generate_name_trace + at trace time, not through the metric handler queue path. + """ + logger.debug("Processing GENERATE_NAME: event_id=%s", envelope.event_id) + + def _on_prompt_generation(self, envelope: TelemetryEnvelope) -> None: + """Handle prompt generation event. + + Intentionally a no-op: metrics and structured logs for prompt generation + operations are emitted directly by EnterpriseOtelTrace._prompt_generation_trace + at trace time, not through the metric handler queue path. + """ + logger.debug("Processing PROMPT_GENERATION: event_id=%s", envelope.event_id) diff --git a/api/enterprise/telemetry/telemetry_log.py b/api/enterprise/telemetry/telemetry_log.py new file mode 100644 index 0000000000..8cce4a9fcd --- /dev/null +++ b/api/enterprise/telemetry/telemetry_log.py @@ -0,0 +1,122 @@ +"""Structured-log emitter for enterprise telemetry events. + +Emits structured JSON log lines correlated with OTEL traces via trace_id. +Picked up by ``StructuredJSONFormatter`` → stdout/Loki/Elastic. 
+""" + +from __future__ import annotations + +import logging +import uuid +from functools import lru_cache +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from enterprise.telemetry.entities import EnterpriseTelemetryEvent + +logger = logging.getLogger("dify.telemetry") + + +@lru_cache(maxsize=4096) +def compute_trace_id_hex(uuid_str: str | None) -> str: + """Convert a business UUID string to a 32-hex OTEL-compatible trace_id. + + Returns empty string when *uuid_str* is ``None`` or invalid. + """ + if not uuid_str: + return "" + normalized = uuid_str.strip().lower() + if len(normalized) == 32 and all(ch in "0123456789abcdef" for ch in normalized): + return normalized + try: + return f"{uuid.UUID(normalized).int:032x}" + except (ValueError, AttributeError): + return "" + + +@lru_cache(maxsize=4096) +def compute_span_id_hex(uuid_str: str | None) -> str: + if not uuid_str: + return "" + normalized = uuid_str.strip().lower() + if len(normalized) == 16 and all(ch in "0123456789abcdef" for ch in normalized): + return normalized + try: + from enterprise.telemetry.id_generator import compute_deterministic_span_id + + return f"{compute_deterministic_span_id(normalized):016x}" + except (ValueError, AttributeError): + return "" + + +def emit_telemetry_log( + *, + event_name: str | EnterpriseTelemetryEvent, + attributes: dict[str, Any], + signal: str = "metric_only", + trace_id_source: str | None = None, + span_id_source: str | None = None, + tenant_id: str | None = None, + user_id: str | None = None, +) -> None: + """Emit a structured log line for a telemetry event. + + Parameters + ---------- + event_name: + Canonical event name, e.g. ``"dify.workflow.run"``. + attributes: + All event-specific attributes (already built by the caller). + signal: + ``"metric_only"`` for events with no span, ``"span_detail"`` + for detail logs accompanying a slim span. + trace_id_source: + A UUID string (e.g. 
``workflow_run_id``) used to derive a 32-hex + trace_id for cross-signal correlation. + tenant_id: + Tenant identifier (for the ``IdentityContextFilter``). + user_id: + User identifier (for the ``IdentityContextFilter``). + """ + if not logger.isEnabledFor(logging.INFO): + return + attrs = { + "dify.event.name": event_name, + "dify.event.signal": signal, + **attributes, + } + + extra: dict[str, Any] = {"attributes": attrs} + + trace_id_hex = compute_trace_id_hex(trace_id_source) + if trace_id_hex: + extra["trace_id"] = trace_id_hex + span_id_hex = compute_span_id_hex(span_id_source) + if span_id_hex: + extra["span_id"] = span_id_hex + if tenant_id: + extra["tenant_id"] = tenant_id + if user_id: + extra["user_id"] = user_id + + logger.info("telemetry.%s", signal, extra=extra) + + +def emit_metric_only_event( + *, + event_name: str | EnterpriseTelemetryEvent, + attributes: dict[str, Any], + trace_id_source: str | None = None, + span_id_source: str | None = None, + tenant_id: str | None = None, + user_id: str | None = None, +) -> None: + emit_telemetry_log( + event_name=event_name, + attributes=attributes, + signal="metric_only", + trace_id_source=trace_id_source, + span_id_source=span_id_source, + tenant_id=tenant_id, + user_id=user_id, + ) diff --git a/api/events/app_event.py b/api/events/app_event.py index f2ce71bbbb..2fba0028f9 100644 --- a/api/events/app_event.py +++ b/api/events/app_event.py @@ -11,3 +11,9 @@ app_published_workflow_was_updated = signal("app-published-workflow-was-updated" # sender: app, kwargs: synced_draft_workflow app_draft_workflow_was_synced = signal("app-draft-workflow-was-synced") + +# sender: app +app_was_updated = signal("app-was-updated") + +# sender: app +app_was_deleted = signal("app-was-deleted") diff --git a/api/events/event_handlers/create_site_record_when_app_created.py b/api/events/event_handlers/create_site_record_when_app_created.py index 5e7caf8cbe..84be592b1a 100644 --- 
a/api/events/event_handlers/create_site_record_when_app_created.py +++ b/api/events/event_handlers/create_site_record_when_app_created.py @@ -1,5 +1,6 @@ from events.app_event import app_was_created from extensions.ext_database import db +from models.enums import CustomizeTokenStrategy from models.model import Site @@ -16,7 +17,7 @@ def handle(sender, **kwargs): icon=app.icon, icon_background=app.icon_background, default_language=account.interface_language, - customize_token_strategy="not_allow", + customize_token_strategy=CustomizeTokenStrategy.NOT_ALLOW, code=Site.generate_code(16), created_by=app.created_by, updated_by=app.updated_by, diff --git a/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py b/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py index c43e99f0f4..7bd8e88231 100644 --- a/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py +++ b/api/events/event_handlers/delete_tool_parameters_cache_when_sync_draft_workflow.py @@ -1,9 +1,11 @@ import logging +from graphon.nodes import BuiltinNodeTypes +from graphon.nodes.tool.entities import ToolEntity + +from core.tools.entities.tool_entities import ToolProviderType from core.tools.tool_manager import ToolManager from core.tools.utils.configuration import ToolParameterConfigurationManager -from dify_graph.nodes import BuiltinNodeTypes -from dify_graph.nodes.tool.entities import ToolEntity from events.app_event import app_draft_workflow_was_synced logger = logging.getLogger(__name__) @@ -19,8 +21,9 @@ def handle(sender, **kwargs): if node_data.get("data", {}).get("type") == BuiltinNodeTypes.TOOL: try: tool_entity = ToolEntity.model_validate(node_data["data"]) + provider_type = ToolProviderType(tool_entity.provider_type.value) tool_runtime = ToolManager.get_tool_runtime( - provider_type=tool_entity.provider_type, + provider_type=provider_type, provider_id=tool_entity.provider_id, tool_name=tool_entity.tool_name, 
tenant_id=app.tenant_id, @@ -30,7 +33,7 @@ def handle(sender, **kwargs): tenant_id=app.tenant_id, tool_runtime=tool_runtime, provider_name=tool_entity.provider_name, - provider_type=tool_entity.provider_type, + provider_type=provider_type, identity_id=f"WORKFLOW.{app.id}.{node_data.get('id')}", ) manager.delete_tool_parameters_cache() diff --git a/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py b/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py index 20852b818e..86b5b2bbf0 100644 --- a/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py +++ b/api/events/event_handlers/update_app_dataset_join_when_app_published_workflow_updated.py @@ -1,9 +1,9 @@ from typing import cast +from graphon.nodes import BuiltinNodeTypes from sqlalchemy import delete, select from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData -from dify_graph.nodes import BuiltinNodeTypes from events.app_event import app_published_workflow_was_updated from extensions.ext_database import db from models.dataset import AppDatasetJoin diff --git a/api/extensions/ext_celery.py b/api/extensions/ext_celery.py index 7b6a73af52..4eed34436a 100644 --- a/api/extensions/ext_celery.py +++ b/api/extensions/ext_celery.py @@ -2,7 +2,7 @@ import ssl from datetime import timedelta from typing import Any -import pytz +import pytz # type: ignore[import-untyped] from celery import Celery, Task from celery.schedules import crontab @@ -204,6 +204,8 @@ def init_app(app: DifyApp) -> Celery: "schedule": timedelta(minutes=dify_config.API_TOKEN_LAST_USED_UPDATE_INTERVAL), } + if dify_config.ENTERPRISE_ENABLED and dify_config.ENTERPRISE_TELEMETRY_ENABLED: + imports.append("tasks.enterprise_telemetry_task") celery_app.conf.update(beat_schedule=beat_schedule, imports=imports) return celery_app diff --git a/api/extensions/ext_enterprise_telemetry.py 
b/api/extensions/ext_enterprise_telemetry.py new file mode 100644 index 0000000000..b3cfa01aee --- /dev/null +++ b/api/extensions/ext_enterprise_telemetry.py @@ -0,0 +1,50 @@ +"""Flask extension for enterprise telemetry lifecycle management. + +Initializes the EnterpriseExporter singleton during ``create_app()`` +(single-threaded), registers blinker event handlers, and hooks atexit +for graceful shutdown. + +Skipped entirely when either ``ENTERPRISE_ENABLED`` or ``ENTERPRISE_TELEMETRY_ENABLED`` +is false (``is_enabled()`` gate). +""" + +from __future__ import annotations + +import atexit +import logging +from typing import TYPE_CHECKING + +from configs import dify_config + +if TYPE_CHECKING: + from dify_app import DifyApp + from enterprise.telemetry.exporter import EnterpriseExporter + +logger = logging.getLogger(__name__) + +_exporter: EnterpriseExporter | None = None + + +def is_enabled() -> bool: + return bool(dify_config.ENTERPRISE_ENABLED and dify_config.ENTERPRISE_TELEMETRY_ENABLED) + + +def init_app(app: DifyApp) -> None: + global _exporter + + if not is_enabled(): + return + + from enterprise.telemetry.exporter import EnterpriseExporter + + _exporter = EnterpriseExporter(dify_config) + atexit.register(_exporter.shutdown) + + # Import to trigger @signal.connect decorator registration + import enterprise.telemetry.event_handlers # noqa: F401 # type: ignore[reportUnusedImport] + + logger.info("Enterprise telemetry initialized") + + +def get_enterprise_exporter() -> EnterpriseExporter | None: + return _exporter diff --git a/api/extensions/ext_otel.py b/api/extensions/ext_otel.py index a5baa21018..63edbe93e7 100644 --- a/api/extensions/ext_otel.py +++ b/api/extensions/ext_otel.py @@ -78,16 +78,24 @@ def init_app(app: DifyApp): protocol = (dify_config.OTEL_EXPORTER_OTLP_PROTOCOL or "").lower() if dify_config.OTEL_EXPORTER_TYPE == "otlp": if protocol == "grpc": + # Auto-detect TLS: https:// uses secure, everything else is insecure + endpoint = 
dify_config.OTLP_BASE_ENDPOINT + insecure = not endpoint.startswith("https://") + + # Header field names must consist of lowercase letters, check RFC7540 + grpc_headers = ( + (("authorization", f"Bearer {dify_config.OTLP_API_KEY}"),) if dify_config.OTLP_API_KEY else () + ) + exporter = GRPCSpanExporter( - endpoint=dify_config.OTLP_BASE_ENDPOINT, - # Header field names must consist of lowercase letters, check RFC7540 - headers=(("authorization", f"Bearer {dify_config.OTLP_API_KEY}"),), - insecure=True, + endpoint=endpoint, + headers=grpc_headers, + insecure=insecure, ) metric_exporter = GRPCMetricExporter( - endpoint=dify_config.OTLP_BASE_ENDPOINT, - headers=(("authorization", f"Bearer {dify_config.OTLP_API_KEY}"),), - insecure=True, + endpoint=endpoint, + headers=grpc_headers, + insecure=insecure, ) else: headers = {"Authorization": f"Bearer {dify_config.OTLP_API_KEY}"} if dify_config.OTLP_API_KEY else None diff --git a/api/extensions/ext_sentry.py b/api/extensions/ext_sentry.py index 9a34acb0c1..651f8ed898 100644 --- a/api/extensions/ext_sentry.py +++ b/api/extensions/ext_sentry.py @@ -5,13 +5,12 @@ from dify_app import DifyApp def init_app(app: DifyApp): if dify_config.SENTRY_DSN: import sentry_sdk + from graphon.model_runtime.errors.invoke import InvokeRateLimitError from langfuse import parse_error from sentry_sdk.integrations.celery import CeleryIntegration from sentry_sdk.integrations.flask import FlaskIntegration from werkzeug.exceptions import HTTPException - from dify_graph.model_runtime.errors.invoke import InvokeRateLimitError - def before_send(event, hint): if "exc_info" in hint: _, exc_value, _ = hint["exc_info"] diff --git a/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py b/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py index a94d75ec76..db599c5d49 100644 --- a/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py +++ 
b/api/extensions/logstore/repositories/logstore_api_workflow_node_execution_repository.py @@ -11,9 +11,9 @@ from collections.abc import Sequence from datetime import datetime from typing import Any +from graphon.enums import WorkflowNodeExecutionStatus from sqlalchemy.orm import sessionmaker -from dify_graph.enums import WorkflowNodeExecutionStatus from extensions.logstore.aliyun_logstore import AliyunLogStore from extensions.logstore.repositories import safe_float, safe_int from extensions.logstore.sql_escape import escape_identifier, escape_logstore_query_value @@ -60,7 +60,7 @@ def _dict_to_workflow_node_execution_model(data: dict[str, Any]) -> WorkflowNode model.triggered_from = WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN model.node_id = data.get("node_id") or "" model.node_type = data.get("node_type") or "" - model.status = data.get("status") or "running" # Default status if missing + model.status = WorkflowNodeExecutionStatus(data.get("status") or "running") model.title = data.get("title") or "" created_by_role_val = data.get("created_by_role") try: diff --git a/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py b/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py index bdfc81bd1c..3c83ab4f84 100644 --- a/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py +++ b/api/extensions/logstore/repositories/logstore_api_workflow_run_repository.py @@ -20,9 +20,9 @@ from collections.abc import Sequence from datetime import datetime from typing import Any, cast +from graphon.enums import WorkflowExecutionStatus from sqlalchemy.orm import sessionmaker -from dify_graph.enums import WorkflowExecutionStatus from extensions.logstore.aliyun_logstore import AliyunLogStore from extensions.logstore.repositories import safe_float, safe_int from extensions.logstore.sql_escape import escape_identifier, escape_logstore_query_value, escape_sql_string diff --git 
a/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py b/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py index c58aa6adbb..f71b2fa1df 100644 --- a/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py +++ b/api/extensions/logstore/repositories/logstore_workflow_execution_repository.py @@ -4,13 +4,13 @@ import os import time from typing import Union +from graphon.entities import WorkflowExecution +from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker +from core.repositories.factory import WorkflowExecutionRepository from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository -from dify_graph.entities import WorkflowExecution -from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository -from dify_graph.workflow_type_encoder import WorkflowRuntimeTypeConverter from extensions.logstore.aliyun_logstore import AliyunLogStore from libs.helper import extract_tenant_id from models import ( diff --git a/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py b/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py index d84c0bc432..b725436681 100644 --- a/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py +++ b/api/extensions/logstore/repositories/logstore_workflow_node_execution_repository.py @@ -13,15 +13,15 @@ from collections.abc import Sequence from datetime import datetime from typing import Any, Union +from graphon.entities import WorkflowNodeExecution +from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from graphon.model_runtime.utils.encoders import jsonable_encoder +from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from sqlalchemy.engine import Engine from 
sqlalchemy.orm import sessionmaker from core.repositories import SQLAlchemyWorkflowNodeExecutionRepository -from dify_graph.entities import WorkflowNodeExecution -from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus -from dify_graph.model_runtime.utils.encoders import jsonable_encoder -from dify_graph.repositories.workflow_node_execution_repository import OrderConfig, WorkflowNodeExecutionRepository -from dify_graph.workflow_type_encoder import WorkflowRuntimeTypeConverter +from core.repositories.factory import OrderConfig, WorkflowNodeExecutionRepository from extensions.logstore.aliyun_logstore import AliyunLogStore from extensions.logstore.repositories import safe_float, safe_int from extensions.logstore.sql_escape import escape_identifier @@ -304,35 +304,39 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository): logger.exception("Failed to dual-write node execution data to SQL database: id=%s", execution.id) # Don't raise - LogStore write succeeded, SQL is just a backup - def get_by_workflow_run( + def get_by_workflow_execution( self, - workflow_run_id: str, + workflow_execution_id: str, order_config: OrderConfig | None = None, ) -> Sequence[WorkflowNodeExecution]: """ - Retrieve all NodeExecution instances for a specific workflow run. + Retrieve all node executions for a workflow execution. Uses LogStore SQL query with window function to get the latest version of each node execution. This ensures we only get the most recent version of each node execution record. 
Args: - workflow_run_id: The workflow run ID + workflow_execution_id: The workflow execution identifier order_config: Optional configuration for ordering results order_config.order_by: List of fields to order by (e.g., ["index", "created_at"]) order_config.order_direction: Direction to order ("asc" or "desc") Returns: - A list of NodeExecution instances + A list of workflow node execution instances Note: This method uses ROW_NUMBER() window function partitioned by node_execution_id to get the latest version (highest log_version) of each node execution. """ - logger.debug("get_by_workflow_run: workflow_run_id=%s, order_config=%s", workflow_run_id, order_config) + logger.debug( + "get_by_workflow_execution: workflow_execution_id=%s, order_config=%s", + workflow_execution_id, + order_config, + ) # Build SQL query with deduplication using window function # ROW_NUMBER() OVER (PARTITION BY node_execution_id ORDER BY log_version DESC) # ensures we get the latest version of each node execution # Escape parameters to prevent SQL injection - escaped_workflow_run_id = escape_identifier(workflow_run_id) + escaped_workflow_execution_id = escape_identifier(workflow_execution_id) escaped_tenant_id = escape_identifier(self._tenant_id) # Build ORDER BY clause for outer query @@ -360,7 +364,7 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository): SELECT * FROM ( SELECT *, ROW_NUMBER() OVER (PARTITION BY node_execution_id ORDER BY log_version DESC) AS rn FROM {AliyunLogStore.workflow_node_execution_logstore} - WHERE workflow_run_id='{escaped_workflow_run_id}' + WHERE workflow_run_id='{escaped_workflow_execution_id}' AND tenant_id='{escaped_tenant_id}' {app_id_filter} ) t @@ -391,5 +395,8 @@ class LogstoreWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository): return executions except Exception: - logger.exception("Failed to retrieve node executions from LogStore: workflow_run_id=%s", workflow_run_id) + logger.exception( + "Failed to retrieve node 
executions from LogStore: workflow_execution_id=%s", + workflow_execution_id, + ) raise diff --git a/api/extensions/otel/parser/__init__.py b/api/extensions/otel/parser/__init__.py index 164db7c275..c671e8b409 100644 --- a/api/extensions/otel/parser/__init__.py +++ b/api/extensions/otel/parser/__init__.py @@ -5,7 +5,7 @@ This module provides parsers that extract node-specific metadata and set OpenTelemetry span attributes according to semantic conventions. """ -from extensions.otel.parser.base import DefaultNodeOTelParser, NodeOTelParser, safe_json_dumps +from extensions.otel.parser.base import DefaultNodeOTelParser, NodeOTelParser, safe_json_dumps, should_include_content from extensions.otel.parser.llm import LLMNodeOTelParser from extensions.otel.parser.retrieval import RetrievalNodeOTelParser from extensions.otel.parser.tool import ToolNodeOTelParser @@ -17,4 +17,5 @@ __all__ = [ "RetrievalNodeOTelParser", "ToolNodeOTelParser", "safe_json_dumps", + "should_include_content", ] diff --git a/api/extensions/otel/parser/base.py b/api/extensions/otel/parser/base.py index 544ef3fe18..23d324f9ea 100644 --- a/api/extensions/otel/parser/base.py +++ b/api/extensions/otel/parser/base.py @@ -1,22 +1,38 @@ """ Base parser interface and utilities for OpenTelemetry node parsers. + +Content gating: ``should_include_content()`` controls whether content-bearing +span attributes (inputs, outputs, prompts, completions, documents) are written. +Gate is only active in EE (``ENTERPRISE_ENABLED=True``) when +``ENTERPRISE_INCLUDE_CONTENT=False``; CE behaviour is unchanged. 
""" import json from typing import Any, Protocol +from graphon.enums import BuiltinNodeTypes +from graphon.file import File +from graphon.graph_events import GraphNodeEventBase +from graphon.nodes.base.node import Node +from graphon.variables import Segment from opentelemetry.trace import Span from opentelemetry.trace.status import Status, StatusCode from pydantic import BaseModel -from dify_graph.enums import BuiltinNodeTypes -from dify_graph.file.models import File -from dify_graph.graph_events import GraphNodeEventBase -from dify_graph.nodes.base.node import Node -from dify_graph.variables import Segment +from configs import dify_config from extensions.otel.semconv.gen_ai import ChainAttributes, GenAIAttributes +def should_include_content() -> bool: + """Return True if content should be written to spans. + + CE (ENTERPRISE_ENABLED=False): always True — no behaviour change. + """ + if not dify_config.ENTERPRISE_ENABLED: + return True + return dify_config.ENTERPRISE_INCLUDE_CONTENT + + def safe_json_dumps(obj: Any, ensure_ascii: bool = False) -> str: """ Safely serialize objects to JSON, handling non-serializable types. 
@@ -101,10 +117,11 @@ class DefaultNodeOTelParser: # Extract inputs and outputs from result_event if result_event and result_event.node_run_result: node_run_result = result_event.node_run_result - if node_run_result.inputs: - span.set_attribute(ChainAttributes.INPUT_VALUE, safe_json_dumps(node_run_result.inputs)) - if node_run_result.outputs: - span.set_attribute(ChainAttributes.OUTPUT_VALUE, safe_json_dumps(node_run_result.outputs)) + if should_include_content(): + if node_run_result.inputs: + span.set_attribute(ChainAttributes.INPUT_VALUE, safe_json_dumps(node_run_result.inputs)) + if node_run_result.outputs: + span.set_attribute(ChainAttributes.OUTPUT_VALUE, safe_json_dumps(node_run_result.outputs)) if error: span.record_exception(error) diff --git a/api/extensions/otel/parser/llm.py b/api/extensions/otel/parser/llm.py index 3da9a9e97d..335c5cc29e 100644 --- a/api/extensions/otel/parser/llm.py +++ b/api/extensions/otel/parser/llm.py @@ -6,10 +6,10 @@ import logging from collections.abc import Mapping from typing import Any +from graphon.graph_events import GraphNodeEventBase +from graphon.nodes.base.node import Node from opentelemetry.trace import Span -from dify_graph.graph_events import GraphNodeEventBase -from dify_graph.nodes.base.node import Node from extensions.otel.parser.base import DefaultNodeOTelParser, safe_json_dumps from extensions.otel.semconv.gen_ai import LLMAttributes diff --git a/api/extensions/otel/parser/retrieval.py b/api/extensions/otel/parser/retrieval.py index dd658b250b..6df5f62c15 100644 --- a/api/extensions/otel/parser/retrieval.py +++ b/api/extensions/otel/parser/retrieval.py @@ -6,11 +6,11 @@ import logging from collections.abc import Sequence from typing import Any +from graphon.graph_events import GraphNodeEventBase +from graphon.nodes.base.node import Node +from graphon.variables import Segment from opentelemetry.trace import Span -from dify_graph.graph_events import GraphNodeEventBase -from dify_graph.nodes.base.node import Node 
-from dify_graph.variables import Segment from extensions.otel.parser.base import DefaultNodeOTelParser, safe_json_dumps from extensions.otel.semconv.gen_ai import RetrieverAttributes diff --git a/api/extensions/otel/parser/tool.py b/api/extensions/otel/parser/tool.py index f4e6a18b4d..b9fdd9e1ca 100644 --- a/api/extensions/otel/parser/tool.py +++ b/api/extensions/otel/parser/tool.py @@ -2,12 +2,12 @@ Parser for tool nodes that captures tool-specific metadata. """ +from graphon.enums import WorkflowNodeExecutionMetadataKey +from graphon.graph_events import GraphNodeEventBase +from graphon.nodes.base.node import Node +from graphon.nodes.tool.entities import ToolNodeData from opentelemetry.trace import Span -from dify_graph.enums import WorkflowNodeExecutionMetadataKey -from dify_graph.graph_events import GraphNodeEventBase -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.tool.entities import ToolNodeData from extensions.otel.parser.base import DefaultNodeOTelParser, safe_json_dumps from extensions.otel.semconv.gen_ai import ToolAttributes diff --git a/api/extensions/otel/semconv/dify.py b/api/extensions/otel/semconv/dify.py index a20b9b358d..301ddd11aa 100644 --- a/api/extensions/otel/semconv/dify.py +++ b/api/extensions/otel/semconv/dify.py @@ -21,3 +21,15 @@ class DifySpanAttributes: INVOKE_FROM = "dify.invoke_from" """Invocation source, e.g. SERVICE_API, WEB_APP, DEBUGGER.""" + + INVOKED_BY = "dify.invoked_by" + """Invoked by, e.g. 
end_user, account, user.""" + + USAGE_INPUT_TOKENS = "gen_ai.usage.input_tokens" + """Number of input tokens (prompt tokens) used.""" + + USAGE_OUTPUT_TOKENS = "gen_ai.usage.output_tokens" + """Number of output tokens (completion tokens) generated.""" + + USAGE_TOTAL_TOKENS = "gen_ai.usage.total_tokens" + """Total number of tokens used.""" diff --git a/api/factories/file_factory.py b/api/factories/file_factory.py deleted file mode 100644 index cb07ba58ae..0000000000 --- a/api/factories/file_factory.py +++ /dev/null @@ -1,618 +0,0 @@ -import logging -import mimetypes -import os -import re -import urllib.parse -import uuid -from collections.abc import Callable, Mapping, Sequence -from typing import Any - -import httpx -from sqlalchemy import select -from sqlalchemy.orm import Session -from werkzeug.http import parse_options_header - -from constants import AUDIO_EXTENSIONS, DOCUMENT_EXTENSIONS, IMAGE_EXTENSIONS, VIDEO_EXTENSIONS -from core.helper import ssrf_proxy -from dify_graph.file import File, FileBelongsTo, FileTransferMethod, FileType, FileUploadConfig, helpers -from extensions.ext_database import db -from models import MessageFile, ToolFile, UploadFile - -logger = logging.getLogger(__name__) - - -def build_from_message_files( - *, - message_files: Sequence["MessageFile"], - tenant_id: str, - config: FileUploadConfig | None = None, -) -> Sequence[File]: - results = [ - build_from_message_file(message_file=file, tenant_id=tenant_id, config=config) - for file in message_files - if file.belongs_to != FileBelongsTo.ASSISTANT - ] - return results - - -def build_from_message_file( - *, - message_file: "MessageFile", - tenant_id: str, - config: FileUploadConfig | None, -): - mapping = { - "transfer_method": message_file.transfer_method, - "url": message_file.url, - "type": message_file.type, - } - - # Only include id if it exists (message_file has been committed to DB) - if message_file.id: - mapping["id"] = message_file.id - - # Set the correct ID field based on 
transfer method - if message_file.transfer_method == FileTransferMethod.TOOL_FILE: - mapping["tool_file_id"] = message_file.upload_file_id - else: - mapping["upload_file_id"] = message_file.upload_file_id - - return build_from_mapping( - mapping=mapping, - tenant_id=tenant_id, - config=config, - ) - - -def build_from_mapping( - *, - mapping: Mapping[str, Any], - tenant_id: str, - config: FileUploadConfig | None = None, - strict_type_validation: bool = False, -) -> File: - transfer_method_value = mapping.get("transfer_method") - if not transfer_method_value: - raise ValueError("transfer_method is required in file mapping") - transfer_method = FileTransferMethod.value_of(transfer_method_value) - - build_functions: dict[FileTransferMethod, Callable] = { - FileTransferMethod.LOCAL_FILE: _build_from_local_file, - FileTransferMethod.REMOTE_URL: _build_from_remote_url, - FileTransferMethod.TOOL_FILE: _build_from_tool_file, - FileTransferMethod.DATASOURCE_FILE: _build_from_datasource_file, - } - - build_func = build_functions.get(transfer_method) - if not build_func: - raise ValueError(f"Invalid file transfer method: {transfer_method}") - - file: File = build_func( - mapping=mapping, - tenant_id=tenant_id, - transfer_method=transfer_method, - strict_type_validation=strict_type_validation, - ) - - if config and not _is_file_valid_with_config( - input_file_type=mapping.get("type", FileType.CUSTOM), - file_extension=file.extension or "", - file_transfer_method=file.transfer_method, - config=config, - ): - raise ValueError(f"File validation failed for file: {file.filename}") - - return file - - -def build_from_mappings( - *, - mappings: Sequence[Mapping[str, Any]], - config: FileUploadConfig | None = None, - tenant_id: str, - strict_type_validation: bool = False, -) -> Sequence[File]: - # TODO(QuantumGhost): Performance concern - each mapping triggers a separate database query. - # Implement batch processing to reduce database load when handling multiple files. 
- # Filter out None/empty mappings to avoid errors - def is_valid_mapping(m: Mapping[str, Any]) -> bool: - if not m or not m.get("transfer_method"): - return False - # For REMOTE_URL transfer method, ensure url or remote_url is provided and not None - transfer_method = m.get("transfer_method") - if transfer_method == FileTransferMethod.REMOTE_URL: - url = m.get("url") or m.get("remote_url") - if not url: - return False - return True - - valid_mappings = [m for m in mappings if is_valid_mapping(m)] - files = [ - build_from_mapping( - mapping=mapping, - tenant_id=tenant_id, - config=config, - strict_type_validation=strict_type_validation, - ) - for mapping in valid_mappings - ] - - if ( - config - # If image config is set. - and config.image_config - # And the number of image files exceeds the maximum limit - and sum(1 for _ in (filter(lambda x: x.type == FileType.IMAGE, files))) > config.image_config.number_limits - ): - raise ValueError(f"Number of image files exceeds the maximum limit {config.image_config.number_limits}") - if config and config.number_limits and len(files) > config.number_limits: - raise ValueError(f"Number of files exceeds the maximum limit {config.number_limits}") - - return files - - -def _build_from_local_file( - *, - mapping: Mapping[str, Any], - tenant_id: str, - transfer_method: FileTransferMethod, - strict_type_validation: bool = False, -) -> File: - upload_file_id = mapping.get("upload_file_id") - if not upload_file_id: - raise ValueError("Invalid upload file id") - # check if upload_file_id is a valid uuid - try: - uuid.UUID(upload_file_id) - except ValueError: - raise ValueError("Invalid upload file id format") - stmt = select(UploadFile).where( - UploadFile.id == upload_file_id, - UploadFile.tenant_id == tenant_id, - ) - - row = db.session.scalar(stmt) - if row is None: - raise ValueError("Invalid upload file") - - detected_file_type = _standardize_file_type(extension="." 
+ row.extension, mime_type=row.mime_type) - specified_type = mapping.get("type", "custom") - - if strict_type_validation and detected_file_type.value != specified_type: - raise ValueError("Detected file type does not match the specified type. Please verify the file.") - - if specified_type and specified_type != "custom": - file_type = FileType(specified_type) - else: - file_type = detected_file_type - - return File( - id=mapping.get("id"), - filename=row.name, - extension="." + row.extension, - mime_type=row.mime_type, - tenant_id=tenant_id, - type=file_type, - transfer_method=transfer_method, - remote_url=row.source_url, - related_id=mapping.get("upload_file_id"), - size=row.size, - storage_key=row.key, - ) - - -def _build_from_remote_url( - *, - mapping: Mapping[str, Any], - tenant_id: str, - transfer_method: FileTransferMethod, - strict_type_validation: bool = False, -) -> File: - upload_file_id = mapping.get("upload_file_id") - if upload_file_id: - try: - uuid.UUID(upload_file_id) - except ValueError: - raise ValueError("Invalid upload file id format") - stmt = select(UploadFile).where( - UploadFile.id == upload_file_id, - UploadFile.tenant_id == tenant_id, - ) - - upload_file = db.session.scalar(stmt) - if upload_file is None: - raise ValueError("Invalid upload file") - - detected_file_type = _standardize_file_type( - extension="." + upload_file.extension, mime_type=upload_file.mime_type - ) - - specified_type = mapping.get("type") - - if strict_type_validation and specified_type and detected_file_type.value != specified_type: - raise ValueError("Detected file type does not match the specified type. Please verify the file.") - - if specified_type and specified_type != "custom": - file_type = FileType(specified_type) - else: - file_type = detected_file_type - - return File( - id=mapping.get("id"), - filename=upload_file.name, - extension="." 
+ upload_file.extension, - mime_type=upload_file.mime_type, - tenant_id=tenant_id, - type=file_type, - transfer_method=transfer_method, - remote_url=helpers.get_signed_file_url(upload_file_id=str(upload_file_id)), - related_id=mapping.get("upload_file_id"), - size=upload_file.size, - storage_key=upload_file.key, - ) - url = mapping.get("url") or mapping.get("remote_url") - if not url: - raise ValueError("Invalid file url") - - mime_type, filename, file_size = _get_remote_file_info(url) - extension = mimetypes.guess_extension(mime_type) or ("." + filename.split(".")[-1] if "." in filename else ".bin") - - detected_file_type = _standardize_file_type(extension=extension, mime_type=mime_type) - specified_type = mapping.get("type") - - if strict_type_validation and specified_type and detected_file_type.value != specified_type: - raise ValueError("Detected file type does not match the specified type. Please verify the file.") - - if specified_type and specified_type != "custom": - file_type = FileType(specified_type) - else: - file_type = detected_file_type - - return File( - id=mapping.get("id"), - filename=filename, - tenant_id=tenant_id, - type=file_type, - transfer_method=transfer_method, - remote_url=url, - mime_type=mime_type, - extension=extension, - size=file_size, - storage_key="", - ) - - -def _extract_filename(url_path: str, content_disposition: str | None) -> str | None: - filename: str | None = None - # Try to extract from Content-Disposition header first - if content_disposition: - # Manually extract filename* parameter since parse_options_header doesn't support it - filename_star_match = re.search(r"filename\*=([^;]+)", content_disposition) - if filename_star_match: - raw_star = filename_star_match.group(1).strip() - # Remove trailing quotes if present - raw_star = raw_star.removesuffix('"') - # format: charset'lang'value - try: - parts = raw_star.split("'", 2) - charset = (parts[0] or "utf-8").lower() if len(parts) >= 1 else "utf-8" - value = parts[2] if 
len(parts) == 3 else parts[-1] - filename = urllib.parse.unquote(value, encoding=charset, errors="replace") - except Exception: - # Fallback: try to extract value after the last single quote - if "''" in raw_star: - filename = urllib.parse.unquote(raw_star.split("''")[-1]) - else: - filename = urllib.parse.unquote(raw_star) - - if not filename: - # Fallback to regular filename parameter - _, params = parse_options_header(content_disposition) - raw = params.get("filename") - if raw: - # Strip surrounding quotes and percent-decode if present - if len(raw) >= 2 and raw[0] == raw[-1] == '"': - raw = raw[1:-1] - filename = urllib.parse.unquote(raw) - # Fallback to URL path if no filename from header - if not filename: - candidate = os.path.basename(url_path) - filename = urllib.parse.unquote(candidate) if candidate else None - # Defense-in-depth: ensure basename only - if filename: - filename = os.path.basename(filename) - # Return None if filename is empty or only whitespace - if not filename or not filename.strip(): - filename = None - return filename or None - - -def _guess_mime_type(filename: str) -> str: - """Guess MIME type from filename, returning empty string if None.""" - guessed_mime, _ = mimetypes.guess_type(filename) - return guessed_mime or "" - - -def _get_remote_file_info(url: str): - file_size = -1 - parsed_url = urllib.parse.urlparse(url) - url_path = parsed_url.path - filename = os.path.basename(url_path) - - # Initialize mime_type from filename as fallback - mime_type = _guess_mime_type(filename) - - resp = ssrf_proxy.head(url, follow_redirects=True) - if resp.status_code == httpx.codes.OK: - content_disposition = resp.headers.get("Content-Disposition") - extracted_filename = _extract_filename(url_path, content_disposition) - if extracted_filename: - filename = extracted_filename - mime_type = _guess_mime_type(filename) - file_size = int(resp.headers.get("Content-Length", file_size)) - # Fallback to Content-Type header if mime_type is still empty - if 
not mime_type: - mime_type = resp.headers.get("Content-Type", "").split(";")[0].strip() - - if not filename: - extension = mimetypes.guess_extension(mime_type) or ".bin" - filename = f"{uuid.uuid4().hex}{extension}" - if not mime_type: - mime_type = _guess_mime_type(filename) - - return mime_type, filename, file_size - - -def _build_from_tool_file( - *, - mapping: Mapping[str, Any], - tenant_id: str, - transfer_method: FileTransferMethod, - strict_type_validation: bool = False, -) -> File: - # Backward/interop compatibility: allow tool_file_id to come from related_id or URL - tool_file_id = mapping.get("tool_file_id") - - if not tool_file_id: - raise ValueError(f"ToolFile {tool_file_id} not found") - tool_file = db.session.scalar( - select(ToolFile).where( - ToolFile.id == tool_file_id, - ToolFile.tenant_id == tenant_id, - ) - ) - - if tool_file is None: - raise ValueError(f"ToolFile {tool_file_id} not found") - - extension = "." + tool_file.file_key.split(".")[-1] if "." in tool_file.file_key else ".bin" - - detected_file_type = _standardize_file_type(extension=extension, mime_type=tool_file.mimetype) - - specified_type = mapping.get("type") - - if strict_type_validation and specified_type and detected_file_type.value != specified_type: - raise ValueError("Detected file type does not match the specified type. 
Please verify the file.") - - if specified_type and specified_type != "custom": - file_type = FileType(specified_type) - else: - file_type = detected_file_type - - return File( - id=mapping.get("id"), - tenant_id=tenant_id, - filename=tool_file.name, - type=file_type, - transfer_method=transfer_method, - remote_url=tool_file.original_url, - related_id=tool_file.id, - extension=extension, - mime_type=tool_file.mimetype, - size=tool_file.size, - storage_key=tool_file.file_key, - ) - - -def _build_from_datasource_file( - *, - mapping: Mapping[str, Any], - tenant_id: str, - transfer_method: FileTransferMethod, - strict_type_validation: bool = False, -) -> File: - datasource_file_id = mapping.get("datasource_file_id") - if not datasource_file_id: - raise ValueError(f"DatasourceFile {datasource_file_id} not found") - datasource_file = db.session.scalar( - select(UploadFile).where( - UploadFile.id == datasource_file_id, - UploadFile.tenant_id == tenant_id, - ) - ) - - if datasource_file is None: - raise ValueError(f"DatasourceFile {mapping.get('datasource_file_id')} not found") - - extension = "." + datasource_file.key.split(".")[-1] if "." in datasource_file.key else ".bin" - - detected_file_type = _standardize_file_type(extension="." + extension, mime_type=datasource_file.mime_type) - - specified_type = mapping.get("type") - - if strict_type_validation and specified_type and detected_file_type.value != specified_type: - raise ValueError("Detected file type does not match the specified type. 
Please verify the file.") - - if specified_type and specified_type != "custom": - file_type = FileType(specified_type) - else: - file_type = detected_file_type - - return File( - id=mapping.get("datasource_file_id"), - tenant_id=tenant_id, - filename=datasource_file.name, - type=file_type, - transfer_method=FileTransferMethod.TOOL_FILE, - remote_url=datasource_file.source_url, - related_id=datasource_file.id, - extension=extension, - mime_type=datasource_file.mime_type, - size=datasource_file.size, - storage_key=datasource_file.key, - url=datasource_file.source_url, - ) - - -def _is_file_valid_with_config( - *, - input_file_type: str, - file_extension: str, - file_transfer_method: FileTransferMethod, - config: FileUploadConfig, -) -> bool: - # FIXME(QIN2DIM): Always allow tool files (files generated by the assistant/model) - # These are internally generated and should bypass user upload restrictions - if file_transfer_method == FileTransferMethod.TOOL_FILE: - return True - - if ( - config.allowed_file_types - and input_file_type not in config.allowed_file_types - and input_file_type != FileType.CUSTOM - ): - return False - - if ( - input_file_type == FileType.CUSTOM - and config.allowed_file_extensions is not None - and file_extension not in config.allowed_file_extensions - ): - return False - - if input_file_type == FileType.IMAGE: - if ( - config.image_config - and config.image_config.transfer_methods - and file_transfer_method not in config.image_config.transfer_methods - ): - return False - elif config.allowed_file_upload_methods and file_transfer_method not in config.allowed_file_upload_methods: - return False - - return True - - -def _standardize_file_type(*, extension: str = "", mime_type: str = "") -> FileType: - """ - Infer the possible actual type of the file based on the extension and mime_type - """ - guessed_type = None - if extension: - guessed_type = _get_file_type_by_extension(extension) - if guessed_type is None and mime_type: - guessed_type = 
_get_file_type_by_mimetype(mime_type) - return guessed_type or FileType.CUSTOM - - -def _get_file_type_by_extension(extension: str) -> FileType | None: - extension = extension.lstrip(".") - if extension in IMAGE_EXTENSIONS: - return FileType.IMAGE - elif extension in VIDEO_EXTENSIONS: - return FileType.VIDEO - elif extension in AUDIO_EXTENSIONS: - return FileType.AUDIO - elif extension in DOCUMENT_EXTENSIONS: - return FileType.DOCUMENT - return None - - -def _get_file_type_by_mimetype(mime_type: str) -> FileType | None: - if "image" in mime_type: - file_type = FileType.IMAGE - elif "video" in mime_type: - file_type = FileType.VIDEO - elif "audio" in mime_type: - file_type = FileType.AUDIO - elif "text" in mime_type or "pdf" in mime_type: - file_type = FileType.DOCUMENT - else: - file_type = FileType.CUSTOM - return file_type - - -def get_file_type_by_mime_type(mime_type: str) -> FileType: - return _get_file_type_by_mimetype(mime_type) or FileType.CUSTOM - - -class StorageKeyLoader: - """FileKeyLoader load the storage key from database for a list of files. - This loader is batched, the database query count is constant regardless of the input size. 
- """ - - def __init__(self, session: Session, tenant_id: str): - self._session = session - self._tenant_id = tenant_id - - def _load_upload_files(self, upload_file_ids: Sequence[uuid.UUID]) -> Mapping[uuid.UUID, UploadFile]: - stmt = select(UploadFile).where( - UploadFile.id.in_(upload_file_ids), - UploadFile.tenant_id == self._tenant_id, - ) - - return {uuid.UUID(i.id): i for i in self._session.scalars(stmt)} - - def _load_tool_files(self, tool_file_ids: Sequence[uuid.UUID]) -> Mapping[uuid.UUID, ToolFile]: - stmt = select(ToolFile).where( - ToolFile.id.in_(tool_file_ids), - ToolFile.tenant_id == self._tenant_id, - ) - return {uuid.UUID(i.id): i for i in self._session.scalars(stmt)} - - def load_storage_keys(self, files: Sequence[File]): - """Loads storage keys for a sequence of files by retrieving the corresponding - `UploadFile` or `ToolFile` records from the database based on their transfer method. - - This method doesn't modify the input sequence structure but updates the `_storage_key` - property of each file object by extracting the relevant key from its database record. - - Performance note: This is a batched operation where database query count remains constant - regardless of input size. However, for optimal performance, input sequences should contain - fewer than 1000 files. For larger collections, split into smaller batches and process each - batch separately. 
- """ - - upload_file_ids: list[uuid.UUID] = [] - tool_file_ids: list[uuid.UUID] = [] - for file in files: - related_model_id = file.related_id - if file.related_id is None: - raise ValueError("file id should not be None.") - if file.tenant_id != self._tenant_id: - err_msg = ( - f"invalid file, expected tenant_id={self._tenant_id}, " - f"got tenant_id={file.tenant_id}, file_id={file.id}, related_model_id={related_model_id}" - ) - raise ValueError(err_msg) - model_id = uuid.UUID(related_model_id) - - if file.transfer_method in (FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL): - upload_file_ids.append(model_id) - elif file.transfer_method == FileTransferMethod.TOOL_FILE: - tool_file_ids.append(model_id) - - tool_files = self._load_tool_files(tool_file_ids) - upload_files = self._load_upload_files(upload_file_ids) - for file in files: - model_id = uuid.UUID(file.related_id) - if file.transfer_method in (FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL): - upload_file_row = upload_files.get(model_id) - if upload_file_row is None: - raise ValueError(f"Upload file not found for id: {model_id}") - file.storage_key = upload_file_row.key - elif file.transfer_method == FileTransferMethod.TOOL_FILE: - tool_file_row = tool_files.get(model_id) - if tool_file_row is None: - raise ValueError(f"Tool file not found for id: {model_id}") - file.storage_key = tool_file_row.file_key diff --git a/api/factories/file_factory/__init__.py b/api/factories/file_factory/__init__.py new file mode 100644 index 0000000000..ae0cd972ec --- /dev/null +++ b/api/factories/file_factory/__init__.py @@ -0,0 +1,18 @@ +"""Workflow file factory package. + +This package normalizes workflow-layer file payloads into graph-layer ``File`` +values. It keeps tenancy and ownership checks in the application layer and +exports the workflow-facing file builders for callers. 
+""" + +from .builders import build_from_mapping, build_from_mappings +from .message_files import build_from_message_file, build_from_message_files +from .storage_keys import StorageKeyLoader + +__all__ = [ + "StorageKeyLoader", + "build_from_mapping", + "build_from_mappings", + "build_from_message_file", + "build_from_message_files", +] diff --git a/api/factories/file_factory/builders.py b/api/factories/file_factory/builders.py new file mode 100644 index 0000000000..7516d18c8e --- /dev/null +++ b/api/factories/file_factory/builders.py @@ -0,0 +1,328 @@ +"""Core builders for workflow file mappings.""" + +from __future__ import annotations + +import mimetypes +import uuid +from collections.abc import Mapping, Sequence +from typing import Any + +from graphon.file import File, FileTransferMethod, FileType, FileUploadConfig, helpers, standardize_file_type +from sqlalchemy import select + +from core.app.file_access import FileAccessControllerProtocol +from core.workflow.file_reference import build_file_reference +from extensions.ext_database import db +from models import ToolFile, UploadFile + +from .common import resolve_mapping_file_id +from .remote import get_remote_file_info +from .validation import is_file_valid_with_config + + +def build_from_mapping( + *, + mapping: Mapping[str, Any], + tenant_id: str, + config: FileUploadConfig | None = None, + strict_type_validation: bool = False, + access_controller: FileAccessControllerProtocol, +) -> File: + transfer_method_value = mapping.get("transfer_method") + if not transfer_method_value: + raise ValueError("transfer_method is required in file mapping") + + transfer_method = FileTransferMethod.value_of(transfer_method_value) + build_func = _get_build_function(transfer_method) + file = build_func( + mapping=mapping, + tenant_id=tenant_id, + transfer_method=transfer_method, + strict_type_validation=strict_type_validation, + access_controller=access_controller, + ) + + if config and not is_file_valid_with_config( + 
input_file_type=mapping.get("type", FileType.CUSTOM), + file_extension=file.extension or "", + file_transfer_method=file.transfer_method, + config=config, + ): + raise ValueError(f"File validation failed for file: {file.filename}") + + return file + + +def build_from_mappings( + *, + mappings: Sequence[Mapping[str, Any]], + config: FileUploadConfig | None = None, + tenant_id: str, + strict_type_validation: bool = False, + access_controller: FileAccessControllerProtocol, +) -> Sequence[File]: + # TODO(QuantumGhost): Performance concern - each mapping triggers a separate database query. + # Implement batch processing to reduce database load when handling multiple files. + valid_mappings = [mapping for mapping in mappings if _is_valid_mapping(mapping)] + files = [ + build_from_mapping( + mapping=mapping, + tenant_id=tenant_id, + config=config, + strict_type_validation=strict_type_validation, + access_controller=access_controller, + ) + for mapping in valid_mappings + ] + + if ( + config + and config.image_config + and sum(1 for file in files if file.type == FileType.IMAGE) > config.image_config.number_limits + ): + raise ValueError(f"Number of image files exceeds the maximum limit {config.image_config.number_limits}") + if config and config.number_limits and len(files) > config.number_limits: + raise ValueError(f"Number of files exceeds the maximum limit {config.number_limits}") + + return files + + +def _get_build_function(transfer_method: FileTransferMethod): + build_functions = { + FileTransferMethod.LOCAL_FILE: _build_from_local_file, + FileTransferMethod.REMOTE_URL: _build_from_remote_url, + FileTransferMethod.TOOL_FILE: _build_from_tool_file, + FileTransferMethod.DATASOURCE_FILE: _build_from_datasource_file, + } + build_func = build_functions.get(transfer_method) + if build_func is None: + raise ValueError(f"Invalid file transfer method: {transfer_method}") + return build_func + + +def _resolve_file_type( + *, + detected_file_type: FileType, + specified_type: 
str | None, + strict_type_validation: bool, +) -> FileType: + if strict_type_validation and specified_type and detected_file_type.value != specified_type: + raise ValueError("Detected file type does not match the specified type. Please verify the file.") + + if specified_type and specified_type != "custom": + return FileType(specified_type) + return detected_file_type + + +def _build_from_local_file( + *, + mapping: Mapping[str, Any], + tenant_id: str, + transfer_method: FileTransferMethod, + strict_type_validation: bool = False, + access_controller: FileAccessControllerProtocol, +) -> File: + upload_file_id = resolve_mapping_file_id(mapping, "upload_file_id") + if not upload_file_id: + raise ValueError("Invalid upload file id") + + try: + uuid.UUID(upload_file_id) + except ValueError as exc: + raise ValueError("Invalid upload file id format") from exc + + stmt = select(UploadFile).where( + UploadFile.id == upload_file_id, + UploadFile.tenant_id == tenant_id, + ) + row = db.session.scalar(access_controller.apply_upload_file_filters(stmt)) + if row is None: + raise ValueError("Invalid upload file") + + detected_file_type = standardize_file_type(extension="." + row.extension, mime_type=row.mime_type) + file_type = _resolve_file_type( + detected_file_type=detected_file_type, + specified_type=mapping.get("type", "custom"), + strict_type_validation=strict_type_validation, + ) + + return File( + id=mapping.get("id"), + filename=row.name, + extension="." 
+ row.extension, + mime_type=row.mime_type, + type=file_type, + transfer_method=transfer_method, + remote_url=row.source_url, + reference=build_file_reference(record_id=str(row.id)), + size=row.size, + storage_key=row.key, + ) + + +def _build_from_remote_url( + *, + mapping: Mapping[str, Any], + tenant_id: str, + transfer_method: FileTransferMethod, + strict_type_validation: bool = False, + access_controller: FileAccessControllerProtocol, +) -> File: + upload_file_id = resolve_mapping_file_id(mapping, "upload_file_id") + if upload_file_id: + try: + uuid.UUID(upload_file_id) + except ValueError as exc: + raise ValueError("Invalid upload file id format") from exc + + stmt = select(UploadFile).where( + UploadFile.id == upload_file_id, + UploadFile.tenant_id == tenant_id, + ) + upload_file = db.session.scalar(access_controller.apply_upload_file_filters(stmt)) + if upload_file is None: + raise ValueError("Invalid upload file") + + detected_file_type = standardize_file_type( + extension="." + upload_file.extension, + mime_type=upload_file.mime_type, + ) + file_type = _resolve_file_type( + detected_file_type=detected_file_type, + specified_type=mapping.get("type"), + strict_type_validation=strict_type_validation, + ) + + return File( + id=mapping.get("id"), + filename=upload_file.name, + extension="." + upload_file.extension, + mime_type=upload_file.mime_type, + type=file_type, + transfer_method=transfer_method, + remote_url=helpers.get_signed_file_url(upload_file_id=str(upload_file_id)), + reference=build_file_reference(record_id=str(upload_file.id)), + size=upload_file.size, + storage_key=upload_file.key, + ) + + url = mapping.get("url") or mapping.get("remote_url") + if not url: + raise ValueError("Invalid file url") + + mime_type, filename, file_size = get_remote_file_info(url) + extension = mimetypes.guess_extension(mime_type) or ("." + filename.split(".")[-1] if "." 
in filename else ".bin") + detected_file_type = standardize_file_type(extension=extension, mime_type=mime_type) + file_type = _resolve_file_type( + detected_file_type=detected_file_type, + specified_type=mapping.get("type"), + strict_type_validation=strict_type_validation, + ) + + return File( + id=mapping.get("id"), + filename=filename, + type=file_type, + transfer_method=transfer_method, + remote_url=url, + mime_type=mime_type, + extension=extension, + size=file_size, + ) + + +def _build_from_tool_file( + *, + mapping: Mapping[str, Any], + tenant_id: str, + transfer_method: FileTransferMethod, + strict_type_validation: bool = False, + access_controller: FileAccessControllerProtocol, +) -> File: + tool_file_id = resolve_mapping_file_id(mapping, "tool_file_id") + if not tool_file_id: + raise ValueError(f"ToolFile {tool_file_id} not found") + + stmt = select(ToolFile).where( + ToolFile.id == tool_file_id, + ToolFile.tenant_id == tenant_id, + ) + tool_file = db.session.scalar(access_controller.apply_tool_file_filters(stmt)) + if tool_file is None: + raise ValueError(f"ToolFile {tool_file_id} not found") + + extension = "." + tool_file.file_key.split(".")[-1] if "." 
in tool_file.file_key else ".bin" + detected_file_type = standardize_file_type(extension=extension, mime_type=tool_file.mimetype) + file_type = _resolve_file_type( + detected_file_type=detected_file_type, + specified_type=mapping.get("type"), + strict_type_validation=strict_type_validation, + ) + + return File( + id=mapping.get("id"), + filename=tool_file.name, + type=file_type, + transfer_method=transfer_method, + remote_url=tool_file.original_url, + reference=build_file_reference(record_id=str(tool_file.id)), + extension=extension, + mime_type=tool_file.mimetype, + size=tool_file.size, + storage_key=tool_file.file_key, + ) + + +def _build_from_datasource_file( + *, + mapping: Mapping[str, Any], + tenant_id: str, + transfer_method: FileTransferMethod, + strict_type_validation: bool = False, + access_controller: FileAccessControllerProtocol, +) -> File: + datasource_file_id = resolve_mapping_file_id(mapping, "datasource_file_id") + if not datasource_file_id: + raise ValueError(f"DatasourceFile {datasource_file_id} not found") + + stmt = select(UploadFile).where( + UploadFile.id == datasource_file_id, + UploadFile.tenant_id == tenant_id, + ) + datasource_file = db.session.scalar(access_controller.apply_upload_file_filters(stmt)) + if datasource_file is None: + raise ValueError(f"DatasourceFile {mapping.get('datasource_file_id')} not found") + + extension = "." + datasource_file.key.split(".")[-1] if "." in datasource_file.key else ".bin" + detected_file_type = standardize_file_type(extension="." 
+ extension, mime_type=datasource_file.mime_type) + file_type = _resolve_file_type( + detected_file_type=detected_file_type, + specified_type=mapping.get("type"), + strict_type_validation=strict_type_validation, + ) + + return File( + id=mapping.get("datasource_file_id"), + filename=datasource_file.name, + type=file_type, + transfer_method=FileTransferMethod.TOOL_FILE, + remote_url=datasource_file.source_url, + reference=build_file_reference(record_id=str(datasource_file.id)), + extension=extension, + mime_type=datasource_file.mime_type, + size=datasource_file.size, + storage_key=datasource_file.key, + url=datasource_file.source_url, + ) + + +def _is_valid_mapping(mapping: Mapping[str, Any]) -> bool: + if not mapping or not mapping.get("transfer_method"): + return False + + if mapping.get("transfer_method") == FileTransferMethod.REMOTE_URL: + url = mapping.get("url") or mapping.get("remote_url") + if not url: + return False + + return True diff --git a/api/factories/file_factory/common.py b/api/factories/file_factory/common.py new file mode 100644 index 0000000000..2e1c95ab3f --- /dev/null +++ b/api/factories/file_factory/common.py @@ -0,0 +1,27 @@ +"""Shared helpers for workflow file factory modules.""" + +from __future__ import annotations + +from collections.abc import Mapping +from typing import Any + +from core.workflow.file_reference import resolve_file_record_id + + +def resolve_mapping_file_id(mapping: Mapping[str, Any], *keys: str) -> str | None: + """Resolve historical file identifiers from persisted mapping payloads. + + Workflow and model payloads can outlive file schema changes. Older rows may + still carry concrete identifiers in legacy fields such as ``related_id``, + while newer payloads use opaque references. Keep this compatibility lookup in + the factory layer so historical data remains readable without reintroducing + storage details into graph-layer ``File`` values. 
+ """ + + for key in (*keys, "reference", "related_id"): + raw_value = mapping.get(key) + if isinstance(raw_value, str) and raw_value: + resolved_value = resolve_file_record_id(raw_value) + if resolved_value: + return resolved_value + return None diff --git a/api/factories/file_factory/message_files.py b/api/factories/file_factory/message_files.py new file mode 100644 index 0000000000..5582b85c95 --- /dev/null +++ b/api/factories/file_factory/message_files.py @@ -0,0 +1,60 @@ +"""Adapters from persisted message files to graph-layer file values.""" + +from __future__ import annotations + +from collections.abc import Sequence + +from graphon.file import File, FileBelongsTo, FileTransferMethod, FileUploadConfig + +from core.app.file_access import FileAccessControllerProtocol +from models import MessageFile + +from .builders import build_from_mapping + + +def build_from_message_files( + *, + message_files: Sequence[MessageFile], + tenant_id: str, + config: FileUploadConfig | None = None, + access_controller: FileAccessControllerProtocol, +) -> Sequence[File]: + return [ + build_from_message_file( + message_file=message_file, + tenant_id=tenant_id, + config=config, + access_controller=access_controller, + ) + for message_file in message_files + if message_file.belongs_to != FileBelongsTo.ASSISTANT + ] + + +def build_from_message_file( + *, + message_file: MessageFile, + tenant_id: str, + config: FileUploadConfig | None, + access_controller: FileAccessControllerProtocol, +) -> File: + mapping = { + "transfer_method": message_file.transfer_method, + "url": message_file.url, + "type": message_file.type, + } + + if message_file.id: + mapping["id"] = message_file.id + + if message_file.transfer_method == FileTransferMethod.TOOL_FILE: + mapping["tool_file_id"] = message_file.upload_file_id + else: + mapping["upload_file_id"] = message_file.upload_file_id + + return build_from_mapping( + mapping=mapping, + tenant_id=tenant_id, + config=config, + 
access_controller=access_controller, + ) diff --git a/api/factories/file_factory/remote.py b/api/factories/file_factory/remote.py new file mode 100644 index 0000000000..e5a7186007 --- /dev/null +++ b/api/factories/file_factory/remote.py @@ -0,0 +1,91 @@ +"""Remote file metadata helpers used by workflow file normalization. + +These helpers are part of the ``factories.file_factory`` package surface +because both workflow builders and tests rely on the same RFC5987 filename +parsing and HEAD-response normalization rules. +""" + +from __future__ import annotations + +import mimetypes +import os +import re +import urllib.parse +import uuid + +import httpx +from werkzeug.http import parse_options_header + +from core.helper import ssrf_proxy + + +def extract_filename(url_path: str, content_disposition: str | None) -> str | None: + """Extract a safe filename from Content-Disposition or the request URL path.""" + filename: str | None = None + if content_disposition: + filename_star_match = re.search(r"filename\*=([^;]+)", content_disposition) + if filename_star_match: + raw_star = filename_star_match.group(1).strip() + raw_star = raw_star.removesuffix('"') + try: + parts = raw_star.split("'", 2) + charset = (parts[0] or "utf-8").lower() if len(parts) >= 1 else "utf-8" + value = parts[2] if len(parts) == 3 else parts[-1] + filename = urllib.parse.unquote(value, encoding=charset, errors="replace") + except Exception: + if "''" in raw_star: + filename = urllib.parse.unquote(raw_star.split("''")[-1]) + else: + filename = urllib.parse.unquote(raw_star) + + if not filename: + _, params = parse_options_header(content_disposition) + raw = params.get("filename") + if raw: + if len(raw) >= 2 and raw[0] == raw[-1] == '"': + raw = raw[1:-1] + filename = urllib.parse.unquote(raw) + + if not filename: + candidate = os.path.basename(url_path) + filename = urllib.parse.unquote(candidate) if candidate else None + + if filename: + filename = os.path.basename(filename) + if not filename or 
not filename.strip(): + filename = None + + return filename or None + + +def _guess_mime_type(filename: str) -> str: + guessed_mime, _ = mimetypes.guess_type(filename) + return guessed_mime or "" + + +def get_remote_file_info(url: str) -> tuple[str, str, int]: + """Resolve remote file metadata with SSRF-safe HEAD probing.""" + file_size = -1 + parsed_url = urllib.parse.urlparse(url) + url_path = parsed_url.path + filename = os.path.basename(url_path) + mime_type = _guess_mime_type(filename) + + resp = ssrf_proxy.head(url, follow_redirects=True) + if resp.status_code == httpx.codes.OK: + content_disposition = resp.headers.get("Content-Disposition") + extracted_filename = extract_filename(url_path, content_disposition) + if extracted_filename: + filename = extracted_filename + mime_type = _guess_mime_type(filename) + file_size = int(resp.headers.get("Content-Length", file_size)) + if not mime_type: + mime_type = resp.headers.get("Content-Type", "").split(";")[0].strip() + + if not filename: + extension = mimetypes.guess_extension(mime_type) or ".bin" + filename = f"{uuid.uuid4().hex}{extension}" + if not mime_type: + mime_type = _guess_mime_type(filename) + + return mime_type, filename, file_size diff --git a/api/factories/file_factory/storage_keys.py b/api/factories/file_factory/storage_keys.py new file mode 100644 index 0000000000..db3a7f3015 --- /dev/null +++ b/api/factories/file_factory/storage_keys.py @@ -0,0 +1,106 @@ +"""Batched storage-key hydration for workflow files.""" + +from __future__ import annotations + +import uuid +from collections.abc import Mapping, Sequence + +from graphon.file import File, FileTransferMethod +from sqlalchemy import select +from sqlalchemy.orm import Session + +from core.app.file_access import FileAccessControllerProtocol +from core.workflow.file_reference import build_file_reference, parse_file_reference +from models import ToolFile, UploadFile + + +class StorageKeyLoader: + """Load storage keys for files with a constant number 
of database queries.""" + + _session: Session + _tenant_id: str + _access_controller: FileAccessControllerProtocol + + def __init__( + self, + session: Session, + tenant_id: str, + access_controller: FileAccessControllerProtocol, + ) -> None: + self._session = session + self._tenant_id = tenant_id + self._access_controller = access_controller + + def _load_upload_files(self, upload_file_ids: Sequence[uuid.UUID]) -> Mapping[uuid.UUID, UploadFile]: + stmt = select(UploadFile).where( + UploadFile.id.in_(upload_file_ids), + UploadFile.tenant_id == self._tenant_id, + ) + scoped_stmt = self._access_controller.apply_upload_file_filters(stmt) + return {uuid.UUID(upload_file.id): upload_file for upload_file in self._session.scalars(scoped_stmt)} + + def _load_tool_files(self, tool_file_ids: Sequence[uuid.UUID]) -> Mapping[uuid.UUID, ToolFile]: + stmt = select(ToolFile).where( + ToolFile.id.in_(tool_file_ids), + ToolFile.tenant_id == self._tenant_id, + ) + scoped_stmt = self._access_controller.apply_tool_file_filters(stmt) + return {uuid.UUID(tool_file.id): tool_file for tool_file in self._session.scalars(scoped_stmt)} + + def load_storage_keys(self, files: Sequence[File]) -> None: + """Hydrate storage keys by loading their backing file rows in batches. + + The sequence shape is preserved. Each file is updated in place with a + canonical record reference and storage key loaded from an authorized + database row. Tenant scoping is enforced by this loader's context + rather than by embedding tenant identity or storage paths inside + graph-layer ``File`` values. + + For best performance, prefer batches smaller than 1000 files. 
+ """ + + upload_file_ids: list[uuid.UUID] = [] + tool_file_ids: list[uuid.UUID] = [] + for file in files: + parsed_reference = parse_file_reference(file.reference) + if parsed_reference is None: + raise ValueError("file id should not be None.") + + model_id = uuid.UUID(parsed_reference.record_id) + if file.transfer_method in ( + FileTransferMethod.LOCAL_FILE, + FileTransferMethod.REMOTE_URL, + FileTransferMethod.DATASOURCE_FILE, + ): + upload_file_ids.append(model_id) + elif file.transfer_method == FileTransferMethod.TOOL_FILE: + tool_file_ids.append(model_id) + + tool_files = self._load_tool_files(tool_file_ids) + upload_files = self._load_upload_files(upload_file_ids) + for file in files: + parsed_reference = parse_file_reference(file.reference) + if parsed_reference is None: + raise ValueError("file id should not be None.") + + model_id = uuid.UUID(parsed_reference.record_id) + if file.transfer_method in ( + FileTransferMethod.LOCAL_FILE, + FileTransferMethod.REMOTE_URL, + FileTransferMethod.DATASOURCE_FILE, + ): + upload_file_row = upload_files.get(model_id) + if upload_file_row is None: + raise ValueError(f"Upload file not found for id: {model_id}") + file.reference = build_file_reference( + record_id=str(upload_file_row.id), + ) + file.storage_key = upload_file_row.key + elif file.transfer_method == FileTransferMethod.TOOL_FILE: + tool_file_row = tool_files.get(model_id) + if tool_file_row is None: + raise ValueError(f"Tool file not found for id: {model_id}") + file.reference = build_file_reference( + record_id=str(tool_file_row.id), + ) + file.storage_key = tool_file_row.file_key diff --git a/api/factories/file_factory/validation.py b/api/factories/file_factory/validation.py new file mode 100644 index 0000000000..4c4f6150e4 --- /dev/null +++ b/api/factories/file_factory/validation.py @@ -0,0 +1,44 @@ +"""Validation helpers for workflow file inputs.""" + +from __future__ import annotations + +from graphon.file import FileTransferMethod, FileType, 
FileUploadConfig + + +def is_file_valid_with_config( + *, + input_file_type: str, + file_extension: str, + file_transfer_method: FileTransferMethod, + config: FileUploadConfig, +) -> bool: + # FIXME(QIN2DIM): Always allow tool files (files generated by the assistant/model) + # These are internally generated and should bypass user upload restrictions + if file_transfer_method == FileTransferMethod.TOOL_FILE: + return True + + if ( + config.allowed_file_types + and input_file_type not in config.allowed_file_types + and input_file_type != FileType.CUSTOM + ): + return False + + if ( + input_file_type == FileType.CUSTOM + and config.allowed_file_extensions is not None + and file_extension not in config.allowed_file_extensions + ): + return False + + if input_file_type == FileType.IMAGE: + if ( + config.image_config + and config.image_config.transfer_methods + and file_transfer_method not in config.image_config.transfer_methods + ): + return False + elif config.allowed_file_upload_methods and file_transfer_method not in config.allowed_file_upload_methods: + return False + + return True diff --git a/api/factories/variable_factory.py b/api/factories/variable_factory.py index 14a56bf4a2..57205b5739 100644 --- a/api/factories/variable_factory.py +++ b/api/factories/variable_factory.py @@ -1,75 +1,52 @@ +"""Compatibility factory for non-graph variable bootstrapping. + +Graph runtime segment/variable conversions live under `graphon.variables`. +This module keeps the application-layer mapping helpers and re-exports the +shared conversion functions for legacy callers and tests. 
+""" + from collections.abc import Mapping, Sequence from typing import Any, cast -from uuid import uuid4 -from configs import dify_config -from dify_graph.constants import ( - CONVERSATION_VARIABLE_NODE_ID, - ENVIRONMENT_VARIABLE_NODE_ID, +from graphon.variables.exc import VariableError +from graphon.variables.factory import ( + TypeMismatchError, + UnsupportedSegmentTypeError, + build_segment, + build_segment_with_type, + segment_to_variable, ) -from dify_graph.file import File -from dify_graph.variables.exc import VariableError -from dify_graph.variables.segments import ( - ArrayAnySegment, - ArrayBooleanSegment, - ArrayFileSegment, - ArrayNumberSegment, - ArrayObjectSegment, - ArraySegment, - ArrayStringSegment, - BooleanSegment, - FileSegment, - FloatSegment, - IntegerSegment, - NoneSegment, - ObjectSegment, - Segment, - StringSegment, -) -from dify_graph.variables.types import SegmentType -from dify_graph.variables.variables import ( - ArrayAnyVariable, +from graphon.variables.types import SegmentType +from graphon.variables.variables import ( ArrayBooleanVariable, - ArrayFileVariable, ArrayNumberVariable, ArrayObjectVariable, ArrayStringVariable, BooleanVariable, - FileVariable, FloatVariable, IntegerVariable, - NoneVariable, ObjectVariable, SecretVariable, StringVariable, VariableBase, ) +from configs import dify_config +from core.workflow.variable_prefixes import ( + CONVERSATION_VARIABLE_NODE_ID, + ENVIRONMENT_VARIABLE_NODE_ID, +) -class UnsupportedSegmentTypeError(Exception): - pass - - -class TypeMismatchError(Exception): - pass - - -# Define the constant -SEGMENT_TO_VARIABLE_MAP: Mapping[type[Segment], type[VariableBase]] = { - ArrayAnySegment: ArrayAnyVariable, - ArrayBooleanSegment: ArrayBooleanVariable, - ArrayFileSegment: ArrayFileVariable, - ArrayNumberSegment: ArrayNumberVariable, - ArrayObjectSegment: ArrayObjectVariable, - ArrayStringSegment: ArrayStringVariable, - BooleanSegment: BooleanVariable, - FileSegment: FileVariable, - FloatSegment: 
FloatVariable, - IntegerSegment: IntegerVariable, - NoneSegment: NoneVariable, - ObjectSegment: ObjectVariable, - StringSegment: StringVariable, -} +__all__ = [ + "TypeMismatchError", + "UnsupportedSegmentTypeError", + "build_conversation_variable_from_mapping", + "build_environment_variable_from_mapping", + "build_pipeline_variable_from_mapping", + "build_segment", + "build_segment_with_type", + "segment_to_variable", +] def build_conversation_variable_from_mapping(mapping: Mapping[str, Any], /) -> VariableBase: @@ -135,172 +112,3 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen if not result.selector: result = result.model_copy(update={"selector": selector}) return cast(VariableBase, result) - - -def build_segment(value: Any, /) -> Segment: - # NOTE: If you have runtime type information available, consider using the `build_segment_with_type` - # below - if value is None: - return NoneSegment() - if isinstance(value, Segment): - return value - if isinstance(value, str): - return StringSegment(value=value) - if isinstance(value, bool): - return BooleanSegment(value=value) - if isinstance(value, int): - return IntegerSegment(value=value) - if isinstance(value, float): - return FloatSegment(value=value) - if isinstance(value, dict): - return ObjectSegment(value=value) - if isinstance(value, File): - return FileSegment(value=value) - if isinstance(value, list): - items = [build_segment(item) for item in value] - types = {item.value_type for item in items} - if all(isinstance(item, ArraySegment) for item in items): - return ArrayAnySegment(value=value) - elif len(types) != 1: - if types.issubset({SegmentType.NUMBER, SegmentType.INTEGER, SegmentType.FLOAT}): - return ArrayNumberSegment(value=value) - return ArrayAnySegment(value=value) - - match types.pop(): - case SegmentType.STRING: - return ArrayStringSegment(value=value) - case SegmentType.NUMBER | SegmentType.INTEGER | SegmentType.FLOAT: - return ArrayNumberSegment(value=value) - 
case SegmentType.BOOLEAN: - return ArrayBooleanSegment(value=value) - case SegmentType.OBJECT: - return ArrayObjectSegment(value=value) - case SegmentType.FILE: - return ArrayFileSegment(value=value) - case SegmentType.NONE: - return ArrayAnySegment(value=value) - case _: - # This should be unreachable. - raise ValueError(f"not supported value {value}") - raise ValueError(f"not supported value {value}") - - -_segment_factory: Mapping[SegmentType, type[Segment]] = { - SegmentType.NONE: NoneSegment, - SegmentType.STRING: StringSegment, - SegmentType.INTEGER: IntegerSegment, - SegmentType.FLOAT: FloatSegment, - SegmentType.FILE: FileSegment, - SegmentType.BOOLEAN: BooleanSegment, - SegmentType.OBJECT: ObjectSegment, - # Array types - SegmentType.ARRAY_ANY: ArrayAnySegment, - SegmentType.ARRAY_STRING: ArrayStringSegment, - SegmentType.ARRAY_NUMBER: ArrayNumberSegment, - SegmentType.ARRAY_OBJECT: ArrayObjectSegment, - SegmentType.ARRAY_FILE: ArrayFileSegment, - SegmentType.ARRAY_BOOLEAN: ArrayBooleanSegment, -} - - -def build_segment_with_type(segment_type: SegmentType, value: Any) -> Segment: - """ - Build a segment with explicit type checking. - - This function creates a segment from a value while enforcing type compatibility - with the specified segment_type. It provides stricter type validation compared - to the standard build_segment function. 
- - Args: - segment_type: The expected SegmentType for the resulting segment - value: The value to be converted into a segment - - Returns: - Segment: A segment instance of the appropriate type - - Raises: - TypeMismatchError: If the value type doesn't match the expected segment_type - - Special Cases: - - For empty list [] values, if segment_type is array[*], returns the corresponding array type - - Type validation is performed before segment creation - - Examples: - >>> build_segment_with_type(SegmentType.STRING, "hello") - StringSegment(value="hello") - - >>> build_segment_with_type(SegmentType.ARRAY_STRING, []) - ArrayStringSegment(value=[]) - - >>> build_segment_with_type(SegmentType.STRING, 123) - # Raises TypeMismatchError - """ - # Handle None values - if value is None: - if segment_type == SegmentType.NONE: - return NoneSegment() - else: - raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got None") - - # Handle empty list special case for array types - if isinstance(value, list) and len(value) == 0: - if segment_type == SegmentType.ARRAY_ANY: - return ArrayAnySegment(value=value) - elif segment_type == SegmentType.ARRAY_STRING: - return ArrayStringSegment(value=value) - elif segment_type == SegmentType.ARRAY_BOOLEAN: - return ArrayBooleanSegment(value=value) - elif segment_type == SegmentType.ARRAY_NUMBER: - return ArrayNumberSegment(value=value) - elif segment_type == SegmentType.ARRAY_OBJECT: - return ArrayObjectSegment(value=value) - elif segment_type == SegmentType.ARRAY_FILE: - return ArrayFileSegment(value=value) - else: - raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got empty list") - - inferred_type = SegmentType.infer_segment_type(value) - # Type compatibility checking - if inferred_type is None: - raise TypeMismatchError( - f"Type mismatch: expected {segment_type}, but got python object, type={type(value)}, value={value}" - ) - if inferred_type == segment_type: - segment_class = 
_segment_factory[segment_type] - return segment_class(value_type=segment_type, value=value) - elif segment_type == SegmentType.NUMBER and inferred_type in ( - SegmentType.INTEGER, - SegmentType.FLOAT, - ): - segment_class = _segment_factory[inferred_type] - return segment_class(value_type=inferred_type, value=value) - else: - raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got {inferred_type}, value={value}") - - -def segment_to_variable( - *, - segment: Segment, - selector: Sequence[str], - id: str | None = None, - name: str | None = None, - description: str = "", -) -> VariableBase: - if isinstance(segment, VariableBase): - return segment - name = name or selector[-1] - id = id or str(uuid4()) - - segment_type = type(segment) - if segment_type not in SEGMENT_TO_VARIABLE_MAP: - raise UnsupportedSegmentTypeError(f"not supported segment type {segment_type}") - - variable_class = SEGMENT_TO_VARIABLE_MAP[segment_type] - return variable_class( - id=id, - name=name, - description=description, - value_type=segment.value_type, - value=segment.value, - selector=list(selector), - ) diff --git a/api/fields/_value_type_serializer.py b/api/fields/_value_type_serializer.py index ac7c5376fb..b5acbbbcb4 100644 --- a/api/fields/_value_type_serializer.py +++ b/api/fields/_value_type_serializer.py @@ -1,7 +1,7 @@ from typing import TypedDict -from dify_graph.variables.segments import Segment -from dify_graph.variables.types import SegmentType +from graphon.variables.segments import Segment +from graphon.variables.types import SegmentType class _VarTypedDict(TypedDict, total=False): diff --git a/api/fields/conversation_fields.py b/api/fields/conversation_fields.py index a5c7ddbb11..30d02aeedc 100644 --- a/api/fields/conversation_fields.py +++ b/api/fields/conversation_fields.py @@ -3,10 +3,9 @@ from __future__ import annotations from datetime import datetime from typing import Any, TypeAlias +from graphon.file import File from pydantic import BaseModel, 
ConfigDict, Field, field_validator, model_validator -from dify_graph.file import File - JSONValue: TypeAlias = Any @@ -311,7 +310,9 @@ def to_timestamp(value: datetime | None) -> int | None: def format_files_contained(value: JSONValue) -> JSONValue: if isinstance(value, File): - return value.model_dump() + # Response payloads must preserve legacy file keys like `related_id`/`url` + # while still exposing the new graph-layer `reference` field. + return value.to_dict() if isinstance(value, dict): return {k: format_files_contained(v) for k, v in value.items()} if isinstance(value, list): diff --git a/api/fields/member_fields.py b/api/fields/member_fields.py index 7ee628726b..b8daa5af30 100644 --- a/api/fields/member_fields.py +++ b/api/fields/member_fields.py @@ -3,10 +3,9 @@ from __future__ import annotations from datetime import datetime from flask_restx import fields +from graphon.file import helpers as file_helpers from pydantic import BaseModel, ConfigDict, computed_field, field_validator -from dify_graph.file import helpers as file_helpers - simple_account_fields = { "id": fields.String, "name": fields.String, diff --git a/api/fields/message_fields.py b/api/fields/message_fields.py index 428f92ed33..d982c31aee 100644 --- a/api/fields/message_fields.py +++ b/api/fields/message_fields.py @@ -4,10 +4,10 @@ from datetime import datetime from typing import TypeAlias from uuid import uuid4 +from graphon.file import File from pydantic import BaseModel, ConfigDict, Field, field_validator from core.entities.execution_extra_content import ExecutionExtraContentDomainModel -from dify_graph.file import File from fields.conversation_fields import AgentThought, JSONValue, MessageFile JSONValueType: TypeAlias = JSONValue @@ -133,7 +133,9 @@ def to_timestamp(value: datetime | None) -> int | None: def format_files_contained(value: JSONValueType) -> JSONValueType: if isinstance(value, File): - return value.model_dump() + # Response payloads must preserve legacy file keys like 
`related_id`/`url` + # while still exposing the new graph-layer `reference` field. + return value.to_dict() if isinstance(value, dict): return {k: format_files_contained(v) for k, v in value.items()} if isinstance(value, list): diff --git a/api/fields/raws.py b/api/fields/raws.py index 318dedc25c..4c65cdab7a 100644 --- a/api/fields/raws.py +++ b/api/fields/raws.py @@ -1,6 +1,5 @@ from flask_restx import fields - -from dify_graph.file import File +from graphon.file import File class FilesContainedField(fields.Raw): diff --git a/api/fields/workflow_fields.py b/api/fields/workflow_fields.py index 7ce2139687..b0b6cc0b48 100644 --- a/api/fields/workflow_fields.py +++ b/api/fields/workflow_fields.py @@ -1,7 +1,7 @@ from flask_restx import fields +from graphon.variables import SecretVariable, SegmentType, VariableBase from core.helper import encrypter -from dify_graph.variables import SecretVariable, SegmentType, VariableBase from fields.member_fields import simple_account_fields from libs.helper import TimestampField diff --git a/api/libs/datetime_utils.py b/api/libs/datetime_utils.py index c08578981b..e0a6ec2cac 100644 --- a/api/libs/datetime_utils.py +++ b/api/libs/datetime_utils.py @@ -2,7 +2,7 @@ import abc import datetime from typing import Protocol -import pytz +import pytz # type: ignore[import-untyped] class _NowFunction(Protocol): diff --git a/api/libs/helper.py b/api/libs/helper.py index e7572cc025..a7b3da77ff 100644 --- a/api/libs/helper.py +++ b/api/libs/helper.py @@ -16,13 +16,13 @@ from zoneinfo import available_timezones from flask import Response, stream_with_context from flask_restx import fields +from graphon.file import helpers as file_helpers +from graphon.model_runtime.utils.encoders import jsonable_encoder from pydantic import BaseModel from pydantic.functional_validators import AfterValidator from configs import dify_config from core.app.features.rate_limiting.rate_limit import RateLimitGenerator -from dify_graph.file import helpers as file_helpers 
-from dify_graph.model_runtime.utils.encoders import jsonable_encoder from extensions.ext_redis import redis_client if TYPE_CHECKING: @@ -174,6 +174,18 @@ def normalize_uuid(value: str | UUID) -> str: raise ValueError("must be a valid UUID") from exc +def parse_uuid_str_or_none(value: str | None) -> str | None: + """ + Return None for missing/empty UUID-like values. + + Keep non-empty values unchanged to avoid changing behavior in paths that + currently pass placeholder IDs in tests/mocks. + """ + if value is None or not str(value).strip(): + return None + return str(value) + + UUIDStrOrEmpty = Annotated[str, AfterValidator(normalize_uuid)] diff --git a/api/libs/oauth.py b/api/libs/oauth.py index 1afb42304d..76e741301c 100644 --- a/api/libs/oauth.py +++ b/api/libs/oauth.py @@ -28,6 +28,7 @@ class AccessTokenResponse(TypedDict, total=False): class GitHubEmailRecord(TypedDict, total=False): email: str primary: bool + verified: bool class GitHubRawUserInfo(TypedDict): @@ -130,25 +131,51 @@ class GitHubOAuth(OAuth): response.raise_for_status() user_info = GITHUB_RAW_USER_INFO_ADAPTER.validate_python(_json_object(response)) + # Only call the /user/emails endpoint when the profile email is absent, + # i.e. the user has "Keep my email addresses private" enabled. + resolved_email = user_info.get("email") or "" + if not resolved_email: + resolved_email = self._get_email_from_emails_endpoint(headers) + + return {**user_info, "email": resolved_email} + + @staticmethod + def _get_email_from_emails_endpoint(headers: dict[str, str]) -> str: + """Fetch the best available email from GitHub's /user/emails endpoint. + + Prefers the primary email, then falls back to any verified email. + Returns an empty string when no usable email is found. 
+ """ try: - email_response = httpx.get(self._EMAIL_INFO_URL, headers=headers) + email_response = httpx.get(GitHubOAuth._EMAIL_INFO_URL, headers=headers) email_response.raise_for_status() - email_info = GITHUB_EMAIL_RECORDS_ADAPTER.validate_python(_json_list(email_response)) - primary_email = next((email for email in email_info if email.get("primary") is True), None) + email_records = GITHUB_EMAIL_RECORDS_ADAPTER.validate_python(_json_list(email_response)) except (httpx.HTTPStatusError, ValidationError): logger.warning("Failed to retrieve email from GitHub /user/emails endpoint", exc_info=True) - primary_email = None + return "" - return {**user_info, "email": primary_email.get("email", "") if primary_email else ""} + primary = next((r for r in email_records if r.get("primary") is True), None) + if primary: + return primary.get("email", "") + + # No primary email; try any verified email as a fallback. + verified = next((r for r in email_records if r.get("verified") is True), None) + if verified: + return verified.get("email", "") + + return "" def _transform_user_info(self, raw_info: JsonObject) -> OAuthUserInfo: payload = GITHUB_RAW_USER_INFO_ADAPTER.validate_python(raw_info) - email = payload.get("email") + email = payload.get("email") or "" if not email: - raise ValueError( - 'Dify currently not supports the "Keep my email addresses private" feature,' - " please disable it and login again" - ) + # When no email is available from the profile or /user/emails endpoint, + # fall back to GitHub's noreply address so sign-in can still proceed. + # Use only the numeric ID (not the login) so the address stays stable + # even if the user renames their GitHub account. 
+ github_id = payload["id"] + email = f"{github_id}@users.noreply.github.com" + logger.info("GitHub user %s has no public email; using noreply address", payload["login"]) return OAuthUserInfo(id=str(payload["id"]), name=str(payload.get("name") or ""), email=email) diff --git a/api/libs/schedule_utils.py b/api/libs/schedule_utils.py index 1ab5f499e9..b80a5ea722 100644 --- a/api/libs/schedule_utils.py +++ b/api/libs/schedule_utils.py @@ -1,6 +1,6 @@ from datetime import UTC, datetime -import pytz +import pytz # type: ignore[import-untyped] from croniter import croniter diff --git a/api/models/enums.py b/api/models/enums.py index cdec7b2f12..bf2e927f00 100644 --- a/api/models/enums.py +++ b/api/models/enums.py @@ -158,6 +158,15 @@ class FeedbackFromSource(StrEnum): ADMIN = "admin" +class CustomizeTokenStrategy(StrEnum): + """Site token customization strategy""" + + MUST = "must" + ALLOW = "allow" + NOT_ALLOW = "not_allow" + UUID = "uuid" + + class FeedbackRating(StrEnum): """MessageFeedback rating""" @@ -314,6 +323,13 @@ class MessageChainType(StrEnum): SYSTEM = "system" +class PromptType(StrEnum): + """Prompt configuration type""" + + SIMPLE = "simple" + ADVANCED = "advanced" + + class ProviderQuotaType(StrEnum): PAID = "paid" """hosted paid quota""" diff --git a/api/models/human_input.py b/api/models/human_input.py index 48e7fbb9ea..79c5d62f6a 100644 --- a/api/models/human_input.py +++ b/api/models/human_input.py @@ -3,14 +3,11 @@ from enum import StrEnum from typing import Annotated, Literal, Self, final import sqlalchemy as sa +from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from pydantic import BaseModel, Field from sqlalchemy.orm import Mapped, mapped_column, relationship -from dify_graph.nodes.human_input.enums import ( - DeliveryMethodType, - HumanInputFormKind, - HumanInputFormStatus, -) +from core.workflow.human_input_compat import DeliveryMethodType from libs.helper import generate_string from .base import Base, 
DefaultFieldsMixin diff --git a/api/models/model.py b/api/models/model.py index 68ff37bcaa..066d2acdce 100644 --- a/api/models/model.py +++ b/api/models/model.py @@ -3,16 +3,20 @@ from __future__ import annotations import json import re import uuid -from collections.abc import Mapping, Sequence +from collections.abc import Callable, Mapping, Sequence from datetime import datetime from decimal import Decimal from enum import StrEnum, auto +from functools import lru_cache from typing import TYPE_CHECKING, Any, Literal, NotRequired, cast from uuid import uuid4 import sqlalchemy as sa from flask import request from flask_login import UserMixin # type: ignore[import-untyped] +from graphon.enums import WorkflowExecutionStatus +from graphon.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType +from graphon.file import helpers as file_helpers from sqlalchemy import BigInteger, Float, Index, PrimaryKeyConstraint, String, exists, func, select, text from sqlalchemy.orm import Mapped, Session, mapped_column from typing_extensions import TypedDict @@ -20,12 +24,10 @@ from typing_extensions import TypedDict from configs import dify_config from constants import DEFAULT_FILE_NUMBER_LIMITS from core.tools.signature import sign_tool_file -from dify_graph.enums import WorkflowExecutionStatus -from dify_graph.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType -from dify_graph.file import helpers as file_helpers from extensions.storage.storage_type import StorageType from libs.helper import generate_string # type: ignore[import-not-found] from libs.uuid_utils import uuidv7 +from models.utils.file_input_compat import build_file_from_input_mapping from .account import Account, Tenant from .base import Base, TypeBase, gen_uuidv4_string @@ -38,12 +40,14 @@ from .enums import ( ConversationFromSource, ConversationStatus, CreatorUserRole, + CustomizeTokenStrategy, FeedbackFromSource, FeedbackRating, InvokeFrom, MessageChainType, MessageFileBelongsTo, 
MessageStatus, + PromptType, ProviderQuotaType, TagType, ) @@ -57,6 +61,32 @@ if TYPE_CHECKING: # --- TypedDict definitions for structured dict return types --- +@lru_cache(maxsize=1) +def _get_file_access_controller(): + from core.app.file_access import DatabaseFileAccessController + + return DatabaseFileAccessController() + + +def _resolve_app_tenant_id(app_id: str) -> str: + resolved_tenant_id = db.session.scalar(select(App.tenant_id).where(App.id == app_id)) + if not resolved_tenant_id: + raise ValueError(f"Unable to resolve tenant_id for app {app_id}") + return resolved_tenant_id + + +def _build_app_tenant_resolver(app_id: str, owner_tenant_id: str | None = None) -> Callable[[], str]: + resolved_tenant_id = owner_tenant_id + + def resolve_owner_tenant_id() -> str: + nonlocal resolved_tenant_id + if resolved_tenant_id is None: + resolved_tenant_id = _resolve_app_tenant_id(app_id) + return resolved_tenant_id + + return resolve_owner_tenant_id + + class EnabledConfig(TypedDict): enabled: bool @@ -621,8 +651,11 @@ class AppModelConfig(TypeBase): agent_mode: Mapped[str | None] = mapped_column(LongText, default=None) sensitive_word_avoidance: Mapped[str | None] = mapped_column(LongText, default=None) retriever_resource: Mapped[str | None] = mapped_column(LongText, default=None) - prompt_type: Mapped[str] = mapped_column( - String(255), nullable=False, server_default=sa.text("'simple'"), default="simple" + prompt_type: Mapped[PromptType] = mapped_column( + EnumText(PromptType, length=255), + nullable=False, + server_default=sa.text("'simple'"), + default=PromptType.SIMPLE, ) chat_prompt_config: Mapped[str | None] = mapped_column(LongText, default=None) completion_prompt_config: Mapped[str | None] = mapped_column(LongText, default=None) @@ -774,7 +807,7 @@ class AppModelConfig(TypeBase): "dataset_query_variable": self.dataset_query_variable, "pre_prompt": self.pre_prompt, "agent_mode": self.agent_mode_dict, - "prompt_type": self.prompt_type, + "prompt_type": 
self.prompt_type.value if isinstance(self.prompt_type, PromptType) else self.prompt_type, "chat_prompt_config": self.chat_prompt_config_dict, "completion_prompt_config": self.completion_prompt_config_dict, "dataset_configs": self.dataset_configs_dict, @@ -818,7 +851,7 @@ class AppModelConfig(TypeBase): self.retriever_resource = ( json.dumps(model_config.get("retriever_resource")) if model_config.get("retriever_resource") else None ) - self.prompt_type = model_config.get("prompt_type", "simple") + self.prompt_type = PromptType(model_config.get("prompt_type", "simple")) self.chat_prompt_config = ( json.dumps(model_config.get("chat_prompt_config")) if model_config.get("chat_prompt_config") else None ) @@ -1057,23 +1090,26 @@ class Conversation(Base): @property def inputs(self) -> dict[str, Any]: inputs = self._inputs.copy() + # Compatibility bridge: stored input payloads may come from before or after the + # graph-layer file refactor. Newer rows may omit `tenant_id`, so keep tenant + # resolution at the SQLAlchemy model boundary instead of pushing ownership back + # into `graphon.file.File`. + tenant_resolver = _build_app_tenant_resolver( + app_id=self.app_id, + owner_tenant_id=cast(str | None, getattr(self, "_owner_tenant_id", None)), + ) # Convert file mapping to File object for key, value in inputs.items(): - # NOTE: It's not the best way to implement this, but it's the only way to avoid circular import for now. 
- from factories import file_factory - if ( isinstance(value, dict) and cast(dict[str, Any], value).get("dify_model_identity") == FILE_MODEL_IDENTITY ): value_dict = cast(dict[str, Any], value) - if value_dict["transfer_method"] == FileTransferMethod.TOOL_FILE: - value_dict["tool_file_id"] = value_dict["related_id"] - elif value_dict["transfer_method"] in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL]: - value_dict["upload_file_id"] = value_dict["related_id"] - tenant_id = cast(str, value_dict.get("tenant_id", "")) - inputs[key] = file_factory.build_from_mapping(mapping=value_dict, tenant_id=tenant_id) + inputs[key] = build_file_from_input_mapping( + file_mapping=value_dict, + tenant_resolver=tenant_resolver, + ) elif isinstance(value, list): value_list = cast(list[Any], value) if all( @@ -1086,15 +1122,12 @@ class Conversation(Base): if not isinstance(item, dict): continue item_dict = cast(dict[str, Any], item) - if item_dict["transfer_method"] == FileTransferMethod.TOOL_FILE: - item_dict["tool_file_id"] = item_dict["related_id"] - elif item_dict["transfer_method"] in [ - FileTransferMethod.LOCAL_FILE, - FileTransferMethod.REMOTE_URL, - ]: - item_dict["upload_file_id"] = item_dict["related_id"] - tenant_id = cast(str, item_dict.get("tenant_id", "")) - file_list.append(file_factory.build_from_mapping(mapping=item_dict, tenant_id=tenant_id)) + file_list.append( + build_file_from_input_mapping( + file_mapping=item_dict, + tenant_resolver=tenant_resolver, + ) + ) inputs[key] = file_list return inputs @@ -1402,21 +1435,23 @@ class Message(Base): @property def inputs(self) -> dict[str, Any]: inputs = self._inputs.copy() + # Compatibility bridge: message inputs are persisted as JSON and must remain + # readable across file payload shape changes. Do not assume `tenant_id` + # is serialized into each file mapping going forward. 
+ tenant_resolver = _build_app_tenant_resolver( + app_id=self.app_id, + owner_tenant_id=cast(str | None, getattr(self, "_owner_tenant_id", None)), + ) for key, value in inputs.items(): - # NOTE: It's not the best way to implement this, but it's the only way to avoid circular import for now. - from factories import file_factory - if ( isinstance(value, dict) and cast(dict[str, Any], value).get("dify_model_identity") == FILE_MODEL_IDENTITY ): value_dict = cast(dict[str, Any], value) - if value_dict["transfer_method"] == FileTransferMethod.TOOL_FILE: - value_dict["tool_file_id"] = value_dict["related_id"] - elif value_dict["transfer_method"] in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL]: - value_dict["upload_file_id"] = value_dict["related_id"] - tenant_id = cast(str, value_dict.get("tenant_id", "")) - inputs[key] = file_factory.build_from_mapping(mapping=value_dict, tenant_id=tenant_id) + inputs[key] = build_file_from_input_mapping( + file_mapping=value_dict, + tenant_resolver=tenant_resolver, + ) elif isinstance(value, list): value_list = cast(list[Any], value) if all( @@ -1429,15 +1464,12 @@ class Message(Base): if not isinstance(item, dict): continue item_dict = cast(dict[str, Any], item) - if item_dict["transfer_method"] == FileTransferMethod.TOOL_FILE: - item_dict["tool_file_id"] = item_dict["related_id"] - elif item_dict["transfer_method"] in [ - FileTransferMethod.LOCAL_FILE, - FileTransferMethod.REMOTE_URL, - ]: - item_dict["upload_file_id"] = item_dict["related_id"] - tenant_id = cast(str, item_dict.get("tenant_id", "")) - file_list.append(file_factory.build_from_mapping(mapping=item_dict, tenant_id=tenant_id)) + file_list.append( + build_file_from_input_mapping( + file_mapping=item_dict, + tenant_resolver=tenant_resolver, + ) + ) inputs[key] = file_list return inputs @@ -1612,6 +1644,7 @@ class Message(Base): "upload_file_id": message_file.upload_file_id, }, tenant_id=current_app.tenant_id, + 
access_controller=_get_file_access_controller(), ) elif message_file.transfer_method == FileTransferMethod.REMOTE_URL: if message_file.url is None: @@ -1625,6 +1658,7 @@ class Message(Base): "url": message_file.url, }, tenant_id=current_app.tenant_id, + access_controller=_get_file_access_controller(), ) elif message_file.transfer_method == FileTransferMethod.TOOL_FILE: if message_file.upload_file_id is None: @@ -1639,6 +1673,7 @@ class Message(Base): file = file_factory.build_from_mapping( mapping=mapping, tenant_id=current_app.tenant_id, + access_controller=_get_file_access_controller(), ) else: raise ValueError( @@ -2054,7 +2089,9 @@ class Site(Base): use_icon_as_answer_icon: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) _custom_disclaimer: Mapped[str] = mapped_column("custom_disclaimer", LongText, default="") customize_domain = mapped_column(String(255)) - customize_token_strategy: Mapped[str] = mapped_column(String(255), nullable=False) + customize_token_strategy: Mapped[CustomizeTokenStrategy] = mapped_column( + EnumText(CustomizeTokenStrategy, length=255), nullable=False + ) prompt_public: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false")) status: Mapped[AppStatus] = mapped_column( EnumText(AppStatus, length=255), nullable=False, server_default=sa.text("'normal'"), default=AppStatus.NORMAL diff --git a/api/models/tools.py b/api/models/tools.py index 63b27b9413..d8731fb8a8 100644 --- a/api/models/tools.py +++ b/api/models/tools.py @@ -11,6 +11,7 @@ from deprecated import deprecated from sqlalchemy import ForeignKey, String, func, select from sqlalchemy.orm import Mapped, mapped_column +from core.plugin.entities.plugin_daemon import CredentialType from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_bundle import ApiToolBundle from core.tools.entities.tool_entities import ( @@ -109,8 +110,11 @@ class BuiltinToolProvider(TypeBase): ) is_default: 
Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"), default=False) # credential type, e.g., "api-key", "oauth2" - credential_type: Mapped[str] = mapped_column( - String(32), nullable=False, server_default=sa.text("'api-key'"), default="api-key" + credential_type: Mapped[CredentialType] = mapped_column( + EnumText(CredentialType, length=32), + nullable=False, + server_default=sa.text("'api-key'"), + default=CredentialType.API_KEY, ) expires_at: Mapped[int] = mapped_column(sa.BigInteger, nullable=False, server_default=sa.text("-1"), default=-1) diff --git a/api/models/trigger.py b/api/models/trigger.py index 627b854060..5233a6e271 100644 --- a/api/models/trigger.py +++ b/api/models/trigger.py @@ -102,7 +102,9 @@ class TriggerSubscription(TypeBase): credentials: Mapped[TriggerCredentials] = mapped_column( sa.JSON, nullable=False, comment="Subscription credentials JSON" ) - credential_type: Mapped[str] = mapped_column(String(50), nullable=False, comment="oauth or api_key") + credential_type: Mapped[CredentialType] = mapped_column( + EnumText(CredentialType, length=50), nullable=False, comment="oauth or api_key" + ) credential_expires_at: Mapped[int] = mapped_column( Integer, default=-1, comment="OAuth token expiration timestamp, -1 for never" ) @@ -144,7 +146,7 @@ class TriggerSubscription(TypeBase): endpoint=generate_plugin_trigger_endpoint_url(self.endpoint_id), parameters=self.parameters, properties=self.properties, - credential_type=CredentialType(self.credential_type), + credential_type=self.credential_type, credentials=self.credentials, workflows_in_use=-1, ) diff --git a/api/models/utils/__init__.py b/api/models/utils/__init__.py new file mode 100644 index 0000000000..b390b8106b --- /dev/null +++ b/api/models/utils/__init__.py @@ -0,0 +1,3 @@ +from .file_input_compat import build_file_from_input_mapping + +__all__ = ["build_file_from_input_mapping"] diff --git a/api/models/utils/file_input_compat.py 
b/api/models/utils/file_input_compat.py new file mode 100644 index 0000000000..f71583c1cd --- /dev/null +++ b/api/models/utils/file_input_compat.py @@ -0,0 +1,117 @@ +from __future__ import annotations + +from collections.abc import Callable, Mapping +from functools import lru_cache +from typing import Any + +from graphon.file import File, FileTransferMethod + +from core.workflow.file_reference import parse_file_reference + + +@lru_cache(maxsize=1) +def _get_file_access_controller(): + from core.app.file_access import DatabaseFileAccessController + + return DatabaseFileAccessController() + + +def resolve_file_record_id(file_mapping: Mapping[str, Any]) -> str | None: + reference = file_mapping.get("reference") + if isinstance(reference, str) and reference: + parsed_reference = parse_file_reference(reference) + if parsed_reference is not None: + return parsed_reference.record_id + + related_id = file_mapping.get("related_id") + if isinstance(related_id, str) and related_id: + parsed_reference = parse_file_reference(related_id) + if parsed_reference is not None: + return parsed_reference.record_id + + return None + + +def resolve_file_mapping_tenant_id( + *, + file_mapping: Mapping[str, Any], + tenant_resolver: Callable[[], str], +) -> str: + tenant_id = file_mapping.get("tenant_id") + if isinstance(tenant_id, str) and tenant_id: + return tenant_id + + return tenant_resolver() + + +def build_file_from_stored_mapping( + *, + file_mapping: Mapping[str, Any], + tenant_id: str, +) -> File: + """ + Canonicalize a persisted file payload against the current tenant context. + + Stored JSON rows can outlive file schema changes, so rebuild storage-backed + files through the workflow factory instead of trusting serialized metadata. + Pure external ``REMOTE_URL`` payloads without a backing upload row are + passed through because there is no server-owned record to rebind. 
+ """ + + # NOTE: It's not the best way to implement this, but it's the only way to avoid circular import for now. + from factories import file_factory + + mapping = dict(file_mapping) + mapping.pop("tenant_id", None) + record_id = resolve_file_record_id(mapping) + transfer_method = FileTransferMethod.value_of(mapping["transfer_method"]) + + if transfer_method == FileTransferMethod.TOOL_FILE and record_id: + mapping["tool_file_id"] = record_id + elif transfer_method in [FileTransferMethod.LOCAL_FILE, FileTransferMethod.REMOTE_URL] and record_id: + mapping["upload_file_id"] = record_id + elif transfer_method == FileTransferMethod.DATASOURCE_FILE and record_id: + mapping["datasource_file_id"] = record_id + + if transfer_method == FileTransferMethod.REMOTE_URL and record_id is None: + remote_url = mapping.get("remote_url") + if not isinstance(remote_url, str) or not remote_url: + url = mapping.get("url") + if isinstance(url, str) and url: + mapping["remote_url"] = url + return File.model_validate(mapping) + + return file_factory.build_from_mapping( + mapping=mapping, + tenant_id=tenant_id, + access_controller=_get_file_access_controller(), + ) + + +def build_file_from_input_mapping( + *, + file_mapping: Mapping[str, Any], + tenant_resolver: Callable[[], str], +) -> File: + """ + Rehydrate persisted model input payloads into graph `File` objects. + + This compatibility layer exists because model JSON rows can outlive file payload + schema changes. Legacy rows may carry `related_id` and `tenant_id`, while newer + rows may only carry `reference`. Keep ownership resolution here, at the model + boundary, instead of pushing tenant data back into `graphon.file.File`. 
+ """ + + transfer_method = FileTransferMethod.value_of(file_mapping["transfer_method"]) + record_id = resolve_file_record_id(file_mapping) + if transfer_method == FileTransferMethod.REMOTE_URL and record_id is None: + return build_file_from_stored_mapping( + file_mapping=file_mapping, + tenant_id="", + ) + + tenant_id = resolve_file_mapping_tenant_id(file_mapping=file_mapping, tenant_resolver=tenant_resolver) + return build_file_from_stored_mapping( + file_mapping=file_mapping, + tenant_id=tenant_id, + ) diff --git a/api/models/workflow.py b/api/models/workflow.py index 334ec42058..f8868cb73c 100644 --- a/api/models/workflow.py +++ b/api/models/workflow.py @@ -8,6 +8,19 @@ from typing import TYPE_CHECKING, Any, Optional, TypedDict, Union, cast from uuid import uuid4 import sqlalchemy as sa +from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter +from graphon.entities.pause_reason import HumanInputRequired, PauseReason, PauseReasonType, SchedulingPause +from graphon.enums import ( + BuiltinNodeTypes, + NodeType, + WorkflowExecutionStatus, + WorkflowNodeExecutionMetadataKey, + WorkflowNodeExecutionStatus, +) +from graphon.file import File +from graphon.file.constants import maybe_file_object +from graphon.variables import utils as variable_utils +from graphon.variables.variables import FloatVariable, IntegerVariable, RAGPipelineVariable, StringVariable from sqlalchemy import ( DateTime, Index, @@ -24,17 +37,11 @@ from sqlalchemy.orm import Mapped, mapped_column from typing_extensions import deprecated from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE -from dify_graph.constants import ( +from core.workflow.human_input_compat import normalize_node_config_for_graph +from core.workflow.variable_prefixes import ( CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID, ) -from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter -from dify_graph.entities.pause_reason import HumanInputRequired, PauseReason, 
PauseReasonType, SchedulingPause -from dify_graph.enums import BuiltinNodeTypes, NodeType, WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey -from dify_graph.file.constants import maybe_file_object -from dify_graph.file.models import File -from dify_graph.variables import utils as variable_utils -from dify_graph.variables.variables import FloatVariable, IntegerVariable, RAGPipelineVariable, StringVariable from extensions.ext_storage import Storage from factories.variable_factory import TypeMismatchError, build_segment_with_type from libs.datetime_utils import naive_utc_now @@ -46,9 +53,10 @@ if TYPE_CHECKING: from .model import AppMode, UploadFile +from graphon.variables import SecretVariable, Segment, SegmentType, VariableBase + from constants import DEFAULT_FILE_NUMBER_LIMITS, HIDDEN_VALUE from core.helper import encrypter -from dify_graph.variables import SecretVariable, Segment, SegmentType, VariableBase from factories import variable_factory from libs import helper @@ -57,6 +65,7 @@ from .base import Base, DefaultFieldsMixin, TypeBase from .engine import db from .enums import CreatorUserRole, DraftVariableType, ExecutionOffLoadType, WorkflowRunTriggeredFrom from .types import EnumText, LongText, StringUUID +from .utils.file_input_compat import build_file_from_stored_mapping logger = logging.getLogger(__name__) @@ -64,6 +73,15 @@ SerializedWorkflowValue = dict[str, Any] SerializedWorkflowVariables = dict[str, SerializedWorkflowValue] +def _resolve_workflow_app_tenant_id(app_id: str) -> str: + from .model import App + + tenant_id = db.session.scalar(select(App.tenant_id).where(App.id == app_id)) + if not tenant_id: + raise ValueError(f"Unable to resolve tenant_id for app {app_id}") + return tenant_id + + class WorkflowContentDict(TypedDict): graph: Mapping[str, Any] features: dict[str, Any] @@ -273,7 +291,7 @@ class Workflow(Base): # bug node_config: dict[str, Any] = next(filter(lambda node: node["id"] == node_id, nodes)) except StopIteration: raise 
NodeNotFoundError(node_id) - return NodeConfigDictAdapter.validate_python(node_config) + return NodeConfigDictAdapter.validate_python(normalize_node_config_for_graph(node_config)) @staticmethod def get_node_type_from_node_config(node_config: NodeConfigDict) -> NodeType: @@ -418,7 +436,7 @@ class Workflow(Base): # bug "selected": false, } - For specific node type, refer to `dify_graph.nodes` + For specific node type, refer to `graphon.nodes` """ graph_dict = self.graph_dict if "nodes" not in graph_dict: @@ -930,7 +948,7 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo inputs: Mapped[str | None] = mapped_column(LongText) process_data: Mapped[str | None] = mapped_column(LongText) outputs: Mapped[str | None] = mapped_column(LongText) - status: Mapped[str] = mapped_column(String(255)) + status: Mapped[WorkflowNodeExecutionStatus] = mapped_column(EnumText(WorkflowNodeExecutionStatus, length=255)) error: Mapped[str | None] = mapped_column(LongText) elapsed_time: Mapped[float] = mapped_column(sa.Float, server_default=sa.text("0")) execution_metadata: Mapped[str | None] = mapped_column(LongText) @@ -1449,8 +1467,6 @@ class WorkflowDraftVariable(Base): # From `VARIABLE_PATTERN`, we may conclude that the length of a top level variable is less than # 80 chars. - # - # ref: api/dify_graph/entities/variable_pool.py:18 name: Mapped[str] = mapped_column(sa.String(255), nullable=False) description: Mapped[str] = mapped_column( sa.String(255), @@ -1565,10 +1581,9 @@ class WorkflowDraftVariable(Base): def _loads_value(self) -> Segment: value = json.loads(self.value) - return self.build_segment_with_type(self.value_type, value) + return self.build_segment_from_serialized_value(self.value_type, value) - @staticmethod - def rebuild_file_types(value: Any): + def _rebuild_file_types(self, value: Any): # NOTE(QuantumGhost): Temporary workaround for structured data handling. 
# By this point, `output` has been converted to dict by # `WorkflowEntry.handle_special_values`, so we need to @@ -1582,13 +1597,72 @@ class WorkflowDraftVariable(Base): if isinstance(value, dict): if not maybe_file_object(value): return cast(Any, value) - return File.model_validate(value) + tenant_id = _resolve_workflow_app_tenant_id(self.app_id) + return build_file_from_stored_mapping( + file_mapping=cast(dict[str, Any], value), + tenant_id=tenant_id, + ) elif isinstance(value, list) and value: value_list = cast(list[Any], value) first: Any = value_list[0] if not maybe_file_object(first): return cast(Any, value) - file_list: list[File] = [File.model_validate(cast(dict[str, Any], i)) for i in value_list] + tenant_id = _resolve_workflow_app_tenant_id(self.app_id) + file_list: list[File] = [] + for item in value_list: + file_list.append( + build_file_from_stored_mapping( + file_mapping=cast(dict[str, Any], item), + tenant_id=tenant_id, + ) + ) + return cast(Any, file_list) + else: + return cast(Any, value) + + def build_segment_from_serialized_value(self, segment_type: SegmentType, value: Any) -> Segment: + # Persisted draft variable rows may contain historical file payloads. + # Rebuild them through the file factory so tenant ownership, signed URLs, + # and storage-backed metadata come from canonical records instead of the + # serialized JSON blob. 
+ if segment_type == SegmentType.FILE: + if isinstance(value, File): + return build_segment_with_type(segment_type, value) + elif isinstance(value, dict): + file = self._rebuild_file_types(value) + return build_segment_with_type(segment_type, file) + else: + raise TypeMismatchError(f"expected dict or File for FileSegment, got {type(value)}") + if segment_type == SegmentType.ARRAY_FILE: + if not isinstance(value, list): + raise TypeMismatchError(f"expected list for ArrayFileSegment, got {type(value)}") + file_list = self._rebuild_file_types(value) + return build_segment_with_type(segment_type=segment_type, value=file_list) + + return build_segment_with_type(segment_type=segment_type, value=value) + + @staticmethod + def rebuild_file_types(value: Any): + # Keep the class-level fallback for callers that only need lightweight + # structural reconstruction. Persisted draft-variable payloads should go + # through `build_segment_from_serialized_value()` so file metadata is + # rebuilt from canonical storage records. 
+ if isinstance(value, dict): + if not maybe_file_object(value): + return cast(Any, value) + normalized_file = dict(value) + normalized_file.pop("tenant_id", None) + return File.model_validate(normalized_file) + elif isinstance(value, list) and value: + value_list = cast(list[Any], value) + first: Any = value_list[0] + if not maybe_file_object(first): + return cast(Any, value) + file_list: list[File] = [] + for item in value_list: + normalized_file = dict(cast(dict[str, Any], item)) + normalized_file.pop("tenant_id", None) + file_list.append(File.model_validate(normalized_file)) return cast(Any, file_list) else: return cast(Any, value) diff --git a/api/pyproject.toml b/api/pyproject.toml index 6ef98068e6..f737d0699f 100644 --- a/api/pyproject.toml +++ b/api/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "dify-api" -version = "1.13.2" +version = "1.13.3" requires-python = ">=3.11,<3.13" dependencies = [ @@ -8,7 +8,7 @@ dependencies = [ "arize-phoenix-otel~=0.15.0", "azure-identity==1.25.3", "beautifulsoup4==4.14.3", - "boto3==1.42.73", + "boto3==1.42.78", "bs4~=0.0.1", "cachetools~=5.3.0", "celery~=5.6.2", @@ -28,11 +28,11 @@ dependencies = [ "google-auth-httplib2==0.3.0", "google-cloud-aiplatform>=1.123.0", "googleapis-common-protos>=1.65.0", - "gunicorn~=25.1.0", + "graphon>=0.1.2", + "gunicorn~=25.3.0", "httpx[socks]~=0.28.0", "jieba==0.42.1", "json-repair>=0.55.1", - "jsonschema>=4.25.1", "langfuse~=2.51.3", "langsmith~=0.7.16", "markdown~=3.10.2", @@ -63,7 +63,6 @@ dependencies = [ "psycopg2-binary~=2.9.6", "pycryptodome==3.23.0", "pydantic~=2.12.5", - "pydantic-extra-types~=2.11.0", "pydantic-settings~=2.13.1", "pyjwt~=2.12.0", "pypdfium2==5.6.0", @@ -71,7 +70,7 @@ dependencies = [ "python-dotenv==1.2.2", "pyyaml~=6.0.1", "readabilipy~=0.3.0", - "redis[hiredis]~=7.3.0", + "redis[hiredis]~=7.4.0", "resend~=2.26.0", "sentry-sdk[flask]~=2.55.0", "sqlalchemy~=2.0.29", @@ -81,7 +80,6 @@ dependencies = [ "unstructured[docx,epub,md,ppt,pptx]~=0.21.5", 
"pypandoc~=1.13", "yarl~=1.23.0", - "webvtt-py~=0.5.1", "sseclient-py~=1.9.0", "httpx-sse~=0.4.0", "sendgrid~=6.12.3", @@ -130,7 +128,6 @@ dev = [ "types-defusedxml~=0.7.0", "types-deprecated~=1.3.1", "types-docutils~=0.22.3", - "types-jsonschema~=4.26.0", "types-flask-cors~=6.0.0", "types-flask-migrate~=4.1.0", "types-gevent~=25.9.0", @@ -150,7 +147,7 @@ dev = [ "types-python-dateutil~=2.9.0", "types-pywin32~=311.0.0", "types-pyyaml~=6.0.12", - "types-regex~=2026.2.28", + "types-regex~=2026.3.32", "types-shapely~=2.1.0", "types-simplejson>=3.20.0", "types-six>=1.17.0", @@ -206,7 +203,7 @@ vdb = [ "alibabacloud_gpdb20160503~=5.1.0", "alibabacloud_tea_openapi~=0.4.3", "chromadb==0.5.20", - "clickhouse-connect~=0.14.1", + "clickhouse-connect~=0.15.0", "clickzetta-connector-python>=0.8.102", "couchbase~=4.5.0", "elasticsearch==8.14.0", @@ -219,38 +216,18 @@ vdb = [ "pyobvector~=0.2.17", "qdrant-client==1.9.0", "intersystems-irispython>=5.1.0", - "tablestore==6.4.1", - "tcvectordb~=2.0.0", + "tablestore==6.4.2", + "tcvectordb~=2.1.0", "tidb-vector==0.0.15", "upstash-vector==0.8.0", "volcengine-compat~=1.0.0", "weaviate-client==4.20.4", - "xinference-client~=2.3.1", + "xinference-client~=2.4.0", "mo-vector~=0.1.13", "mysql-connector-python>=9.3.0", "holo-search-sdk>=0.4.1", ] -[tool.mypy] - -[[tool.mypy.overrides]] -# targeted ignores for current type-check errors -# TODO(QuantumGhost): suppress type errors in HITL related code. 
-# fix the type error later -module = [ - "configs.middleware.cache.redis_pubsub_config", - "extensions.ext_redis", - "tasks.workflow_execution_tasks", - "dify_graph.nodes.base.node", - "services.human_input_delivery_test_service", - "core.app.apps.advanced_chat.app_generator", - "controllers.console.human_input_form", - "controllers.console.app.workflow_run", - "repositories.sqlalchemy_api_workflow_node_execution_repository", - "extensions.logstore.repositories.logstore_api_workflow_run_repository", -] -ignore_errors = true - [tool.pyrefly] project-includes = ["."] project-excludes = [".venv", "migrations/"] diff --git a/api/pyrefly-local-excludes.txt b/api/pyrefly-local-excludes.txt index ad3c1e8389..43f604c2de 100644 --- a/api/pyrefly-local-excludes.txt +++ b/api/pyrefly-local-excludes.txt @@ -109,34 +109,16 @@ core/trigger/debug/event_selectors.py core/trigger/entities/entities.py core/trigger/provider.py core/workflow/workflow_entry.py -dify_graph/entities/workflow_execution.py -dify_graph/file/file_manager.py -dify_graph/graph_engine/error_handler.py -dify_graph/graph_engine/layers/execution_limits.py -dify_graph/nodes/agent/agent_node.py -dify_graph/nodes/base/node.py -dify_graph/nodes/code/code_node.py -dify_graph/nodes/datasource/datasource_node.py -dify_graph/nodes/document_extractor/node.py -dify_graph/nodes/human_input/human_input_node.py -dify_graph/nodes/if_else/if_else_node.py -dify_graph/nodes/iteration/iteration_node.py -dify_graph/nodes/knowledge_index/knowledge_index_node.py +enterprise/telemetry/contracts.py +enterprise/telemetry/draft_trace.py +enterprise/telemetry/enterprise_trace.py +enterprise/telemetry/entities/__init__.py +enterprise/telemetry/event_handlers.py +enterprise/telemetry/exporter.py +enterprise/telemetry/id_generator.py +enterprise/telemetry/metric_handler.py +enterprise/telemetry/telemetry_log.py core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py -dify_graph/nodes/list_operator/node.py 
-dify_graph/nodes/llm/node.py -dify_graph/nodes/loop/loop_node.py -dify_graph/nodes/parameter_extractor/parameter_extractor_node.py -dify_graph/nodes/question_classifier/question_classifier_node.py -dify_graph/nodes/start/start_node.py -dify_graph/nodes/template_transform/template_transform_node.py -dify_graph/nodes/tool/tool_node.py -dify_graph/nodes/trigger_plugin/trigger_event_node.py -dify_graph/nodes/trigger_schedule/trigger_schedule_node.py -dify_graph/nodes/trigger_webhook/node.py -dify_graph/nodes/variable_aggregator/variable_aggregator_node.py -dify_graph/nodes/variable_assigner/v1/node.py -dify_graph/nodes/variable_assigner/v2/node.py extensions/logstore/repositories/logstore_api_workflow_run_repository.py extensions/otel/instrumentation.py extensions/otel/runtime.py diff --git a/api/repositories/api_workflow_node_execution_repository.py b/api/repositories/api_workflow_node_execution_repository.py index 2fa065bcc8..3595ea33f0 100644 --- a/api/repositories/api_workflow_node_execution_repository.py +++ b/api/repositories/api_workflow_node_execution_repository.py @@ -16,7 +16,7 @@ from typing import Protocol from sqlalchemy.orm import Session -from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from core.repositories.factory import WorkflowNodeExecutionRepository from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload diff --git a/api/repositories/api_workflow_run_repository.py b/api/repositories/api_workflow_run_repository.py index a96c4acb31..1a2a539c80 100644 --- a/api/repositories/api_workflow_run_repository.py +++ b/api/repositories/api_workflow_run_repository.py @@ -38,11 +38,11 @@ from collections.abc import Callable, Sequence from datetime import datetime from typing import Protocol +from graphon.entities.pause_reason import PauseReason +from graphon.enums import WorkflowType from sqlalchemy.orm import Session -from dify_graph.entities.pause_reason import PauseReason 
-from dify_graph.enums import WorkflowType -from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository +from core.repositories.factory import WorkflowExecutionRepository from libs.infinite_scroll_pagination import InfiniteScrollPagination from models.enums import WorkflowRunTriggeredFrom from models.workflow import WorkflowAppLog, WorkflowArchiveLog, WorkflowPause, WorkflowPauseReason, WorkflowRun diff --git a/api/repositories/entities/workflow_pause.py b/api/repositories/entities/workflow_pause.py index be28b7e613..03ce574dca 100644 --- a/api/repositories/entities/workflow_pause.py +++ b/api/repositories/entities/workflow_pause.py @@ -10,7 +10,7 @@ from abc import ABC, abstractmethod from collections.abc import Sequence from datetime import datetime -from dify_graph.entities.pause_reason import PauseReason +from graphon.entities.pause_reason import PauseReason class WorkflowPauseEntity(ABC): diff --git a/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py b/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py index 77e40fc6fc..d5c6a203b1 100644 --- a/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py +++ b/api/repositories/sqlalchemy_api_workflow_node_execution_repository.py @@ -10,11 +10,11 @@ from collections.abc import Sequence from datetime import datetime from typing import Protocol, cast +from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from sqlalchemy import asc, delete, desc, func, select from sqlalchemy.engine import CursorResult from sqlalchemy.orm import Session, sessionmaker -from dify_graph.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload from repositories.api_workflow_node_execution_repository import ( DifyAPIWorkflowNodeExecutionRepository, diff --git a/api/repositories/sqlalchemy_api_workflow_run_repository.py 
b/api/repositories/sqlalchemy_api_workflow_run_repository.py index fdd3e123e4..413936b542 100644 --- a/api/repositories/sqlalchemy_api_workflow_run_repository.py +++ b/api/repositories/sqlalchemy_api_workflow_run_repository.py @@ -28,14 +28,14 @@ from decimal import Decimal from typing import Any, cast import sqlalchemy as sa +from graphon.entities.pause_reason import HumanInputRequired, PauseReason, PauseReasonType, SchedulingPause +from graphon.enums import WorkflowExecutionStatus, WorkflowType +from graphon.nodes.human_input.entities import FormDefinition from pydantic import ValidationError from sqlalchemy import and_, delete, func, null, or_, select, tuple_ from sqlalchemy.engine import CursorResult from sqlalchemy.orm import Session, selectinload, sessionmaker -from dify_graph.entities.pause_reason import HumanInputRequired, PauseReason, PauseReasonType, SchedulingPause -from dify_graph.enums import WorkflowExecutionStatus, WorkflowType -from dify_graph.nodes.human_input.entities import FormDefinition from extensions.ext_storage import storage from libs.datetime_utils import naive_utc_now from libs.helper import convert_datetime_to_date @@ -43,7 +43,7 @@ from libs.infinite_scroll_pagination import InfiniteScrollPagination from libs.time_parser import get_time_threshold from libs.uuid_utils import uuidv7 from models.enums import WorkflowRunTriggeredFrom -from models.human_input import HumanInputForm, HumanInputFormRecipient, RecipientType +from models.human_input import HumanInputForm from models.workflow import WorkflowAppLog, WorkflowArchiveLog, WorkflowPause, WorkflowPauseReason, WorkflowRun from repositories.api_workflow_run_repository import APIWorkflowRunRepository from repositories.entities.workflow_pause import WorkflowPauseEntity @@ -61,25 +61,13 @@ class _WorkflowRunError(Exception): pass -def _select_recipient_token( - recipients: Sequence[HumanInputFormRecipient], - recipient_type: RecipientType, -) -> str | None: - for recipient in recipients: - 
if recipient.recipient_type == recipient_type and recipient.access_token: - return recipient.access_token - return None - - def _build_human_input_required_reason( reason_model: WorkflowPauseReason, form_model: HumanInputForm | None, - recipients: Sequence[HumanInputFormRecipient], ) -> HumanInputRequired: form_content = "" inputs = [] actions = [] - display_in_ui = False resolved_default_values: dict[str, Any] = {} node_title = "Human Input" form_id = reason_model.form_id @@ -99,25 +87,16 @@ def _build_human_input_required_reason( form_content = definition.form_content inputs = list(definition.inputs) actions = list(definition.user_actions) - display_in_ui = bool(definition.display_in_ui) resolved_default_values = dict(definition.default_values) node_title = definition.node_title or node_title - form_token = ( - _select_recipient_token(recipients, RecipientType.BACKSTAGE) - or _select_recipient_token(recipients, RecipientType.CONSOLE) - or _select_recipient_token(recipients, RecipientType.STANDALONE_WEB_APP) - ) - return HumanInputRequired( form_id=form_id, form_content=form_content, inputs=inputs, actions=actions, - display_in_ui=display_in_ui, node_id=node_id, node_title=node_title, - form_token=form_token, resolved_default_values=resolved_default_values, ) @@ -823,22 +802,16 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository): if reason.type_ == PauseReasonType.HUMAN_INPUT_REQUIRED and reason.form_id ] form_models: dict[str, HumanInputForm] = {} - recipient_models_by_form: dict[str, list[HumanInputFormRecipient]] = {} if form_ids: form_stmt = select(HumanInputForm).where(HumanInputForm.id.in_(form_ids)) for form in session.scalars(form_stmt).all(): form_models[form.id] = form - recipient_stmt = select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id.in_(form_ids)) - for recipient in session.scalars(recipient_stmt).all(): - recipient_models_by_form.setdefault(recipient.form_id, []).append(recipient) - pause_reasons: 
list[PauseReason] = [] for reason in pause_reason_models: if reason.type_ == PauseReasonType.HUMAN_INPUT_REQUIRED: form_model = form_models.get(reason.form_id) - recipients = recipient_models_by_form.get(reason.form_id, []) - pause_reasons.append(_build_human_input_required_reason(reason, form_model, recipients)) + pause_reasons.append(_build_human_input_required_reason(reason, form_model)) else: pause_reasons.append(reason.to_entity()) return pause_reasons diff --git a/api/repositories/sqlalchemy_execution_extra_content_repository.py b/api/repositories/sqlalchemy_execution_extra_content_repository.py index 508db22eb0..feba5f7eb6 100644 --- a/api/repositories/sqlalchemy_execution_extra_content_repository.py +++ b/api/repositories/sqlalchemy_execution_extra_content_repository.py @@ -7,6 +7,9 @@ from collections import defaultdict from collections.abc import Sequence from typing import Any +from graphon.nodes.human_input.entities import FormDefinition +from graphon.nodes.human_input.enums import HumanInputFormStatus +from graphon.nodes.human_input.human_input_node import HumanInputNode from sqlalchemy import select from sqlalchemy.orm import Session, selectinload, sessionmaker @@ -18,9 +21,6 @@ from core.entities.execution_extra_content import ( from core.entities.execution_extra_content import ( HumanInputContent as HumanInputContentDomainModel, ) -from dify_graph.nodes.human_input.entities import FormDefinition -from dify_graph.nodes.human_input.enums import HumanInputFormStatus -from dify_graph.nodes.human_input.human_input_node import HumanInputNode from models.execution_extra_content import ( ExecutionExtraContent as ExecutionExtraContentModel, ) diff --git a/api/services/app_dsl_service.py b/api/services/app_dsl_service.py index 68cb3438ca..dd73e10374 100644 --- a/api/services/app_dsl_service.py +++ b/api/services/app_dsl_service.py @@ -11,6 +11,12 @@ from uuid import uuid4 import yaml from Crypto.Cipher import AES from Crypto.Util.Padding import pad, unpad 
+from graphon.enums import BuiltinNodeTypes +from graphon.model_runtime.utils.encoders import jsonable_encoder +from graphon.nodes.llm.entities import LLMNodeData +from graphon.nodes.parameter_extractor.entities import ParameterExtractorNodeData +from graphon.nodes.question_classifier.entities import QuestionClassifierNodeData +from graphon.nodes.tool.entities import ToolNodeData from packaging import version from packaging.version import parse as parse_version from pydantic import BaseModel, Field @@ -27,12 +33,6 @@ from core.trigger.constants import ( ) from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData from core.workflow.nodes.trigger_schedule.trigger_schedule_node import TriggerScheduleNode -from dify_graph.enums import BuiltinNodeTypes -from dify_graph.model_runtime.utils.encoders import jsonable_encoder -from dify_graph.nodes.llm.entities import LLMNodeData -from dify_graph.nodes.parameter_extractor.entities import ParameterExtractorNodeData -from dify_graph.nodes.question_classifier.entities import QuestionClassifierNodeData -from dify_graph.nodes.tool.entities import ToolNodeData from events.app_event import app_model_config_was_updated, app_was_created from extensions.ext_redis import redis_client from factories import variable_factory diff --git a/api/services/app_service.py b/api/services/app_service.py index 69c7c0c95a..e9aeb6c43d 100644 --- a/api/services/app_service.py +++ b/api/services/app_service.py @@ -4,6 +4,8 @@ from typing import Any, TypedDict, cast import sqlalchemy as sa from flask_sqlalchemy.pagination import Pagination +from graphon.model_runtime.entities.model_entities import ModelPropertyKey, ModelType +from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from configs import dify_config from constants.model_template import default_app_templates @@ -12,9 +14,7 @@ from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from 
core.model_manager import ModelManager from core.tools.tool_manager import ToolManager from core.tools.utils.configuration import ToolParameterConfigurationManager -from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey, ModelType -from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel -from events.app_event import app_was_created +from events.app_event import app_was_created, app_was_deleted, app_was_updated from extensions.ext_database import db from libs.datetime_utils import naive_utc_now from libs.login import current_user @@ -92,7 +92,7 @@ class AppService: default_model_config = default_model_config.copy() if default_model_config else None if default_model_config and "model" in default_model_config: # get model provider - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=account.current_tenant_id or "") # get default model instance try: @@ -124,11 +124,19 @@ class AppService: "completion_params": {}, } else: - provider, model = model_manager.get_default_provider_model_name( - tenant_id=account.current_tenant_id or "", model_type=ModelType.LLM - ) - default_model_config["model"]["provider"] = provider - default_model_config["model"]["name"] = model + try: + provider, model = model_manager.get_default_provider_model_name( + tenant_id=account.current_tenant_id or "", model_type=ModelType.LLM + ) + except Exception: + logger.exception("Get default provider model failed, tenant_id: %s", tenant_id) + provider = default_model_config["model"].get("provider") + model = default_model_config["model"].get("name") + + if provider: + default_model_config["model"]["provider"] = provider + if model: + default_model_config["model"]["name"] = model default_model_dict = default_model_config["model"] default_model_config["model"] = json.dumps(default_model_dict) @@ -197,6 +205,7 @@ class AppService: tenant_id=current_user.current_tenant_id, app_id=app.id, 
agent_tool=agent_tool_entity, + user_id=current_user.id, ) manager = ToolParameterConfigurationManager( tenant_id=current_user.current_tenant_id, @@ -272,6 +281,8 @@ class AppService: app.updated_at = naive_utc_now() db.session.commit() + app_was_updated.send(app) + return app def update_app_name(self, app: App, name: str) -> App: @@ -287,6 +298,8 @@ class AppService: app.updated_at = naive_utc_now() db.session.commit() + app_was_updated.send(app) + return app def update_app_icon(self, app: App, icon: str, icon_background: str) -> App: @@ -304,6 +317,8 @@ class AppService: app.updated_at = naive_utc_now() db.session.commit() + app_was_updated.send(app) + return app def update_app_site_status(self, app: App, enable_site: bool) -> App: @@ -321,6 +336,8 @@ class AppService: app.updated_at = naive_utc_now() db.session.commit() + app_was_updated.send(app) + return app def update_app_api_status(self, app: App, enable_api: bool) -> App: @@ -339,6 +356,8 @@ class AppService: app.updated_at = naive_utc_now() db.session.commit() + app_was_updated.send(app) + return app def delete_app(self, app: App): @@ -346,6 +365,8 @@ class AppService: Delete app :param app: App instance """ + app_was_deleted.send(app) + db.session.delete(app) db.session.commit() diff --git a/api/services/app_task_service.py b/api/services/app_task_service.py index d556230044..0842e9d3e7 100644 --- a/api/services/app_task_service.py +++ b/api/services/app_task_service.py @@ -5,9 +5,10 @@ like stopping tasks, handling both legacy Redis flag mechanism and new GraphEngine command channel mechanism. 
""" +from graphon.graph_engine.manager import GraphEngineManager + from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.entities.app_invoke_entities import InvokeFrom -from dify_graph.graph_engine.manager import GraphEngineManager from extensions.ext_redis import redis_client from models.model import AppMode diff --git a/api/services/audio_service.py b/api/services/audio_service.py index 1794ea9947..90e72d5f34 100644 --- a/api/services/audio_service.py +++ b/api/services/audio_service.py @@ -5,11 +5,11 @@ from collections.abc import Generator from typing import cast from flask import Response, stream_with_context +from graphon.model_runtime.entities.model_entities import ModelType from werkzeug.datastructures import FileStorage from constants import AUDIO_EXTENSIONS from core.model_manager import ModelManager -from dify_graph.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db from models.enums import MessageStatus from models.model import App, AppMode, Message @@ -61,7 +61,7 @@ class AudioService: message = f"Audio size larger than {FILE_SIZE} mb" raise AudioTooLargeServiceError(message) - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=app_model.tenant_id, user_id=end_user) model_instance = model_manager.get_default_model_instance( tenant_id=app_model.tenant_id, model_type=ModelType.SPEECH2TEXT ) @@ -71,7 +71,7 @@ class AudioService: buffer = io.BytesIO(file_content) buffer.name = "temp.mp3" - return {"text": model_instance.invoke_speech2text(file=buffer, user=end_user)} + return {"text": model_instance.invoke_speech2text(file=buffer)} @classmethod def transcript_tts( @@ -109,7 +109,7 @@ class AudioService: voice = cast(str | None, text_to_speech_dict.get("voice")) - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=app_model.tenant_id, user_id=end_user) model_instance = model_manager.get_default_model_instance( 
tenant_id=app_model.tenant_id, model_type=ModelType.TTS ) @@ -123,9 +123,7 @@ class AudioService: else: raise ValueError("Sorry, no voice available.") - return model_instance.invoke_tts( - content_text=text_content.strip(), user=end_user, tenant_id=app_model.tenant_id, voice=voice - ) + return model_instance.invoke_tts(content_text=text_content.strip(), voice=voice) except Exception as e: raise e @@ -155,7 +153,7 @@ class AudioService: @classmethod def transcript_tts_voices(cls, tenant_id: str, language: str): - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id) model_instance = model_manager.get_default_model_instance(tenant_id=tenant_id, model_type=ModelType.TTS) if model_instance is None: raise ProviderNotSupportTextToSpeechServiceError() diff --git a/api/services/auth/api_key_auth_service.py b/api/services/auth/api_key_auth_service.py index 56aaf407ee..3282dcfb11 100644 --- a/api/services/auth/api_key_auth_service.py +++ b/api/services/auth/api_key_auth_service.py @@ -35,15 +35,13 @@ class ApiKeyAuthService: @staticmethod def get_auth_credentials(tenant_id: str, category: str, provider: str): - data_source_api_key_bindings = ( - db.session.query(DataSourceApiKeyAuthBinding) - .where( + data_source_api_key_bindings = db.session.scalar( + select(DataSourceApiKeyAuthBinding).where( DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.category == category, DataSourceApiKeyAuthBinding.provider == provider, DataSourceApiKeyAuthBinding.disabled.is_(False), ) - .first() ) if not data_source_api_key_bindings: return None @@ -54,10 +52,11 @@ class ApiKeyAuthService: @staticmethod def delete_provider_auth(tenant_id: str, binding_id: str): - data_source_api_key_binding = ( - db.session.query(DataSourceApiKeyAuthBinding) - .where(DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.id == binding_id) - .first() + data_source_api_key_binding = db.session.scalar( + 
select(DataSourceApiKeyAuthBinding).where( + DataSourceApiKeyAuthBinding.tenant_id == tenant_id, + DataSourceApiKeyAuthBinding.id == binding_id, + ) ) if data_source_api_key_binding: db.session.delete(data_source_api_key_binding) diff --git a/api/services/clear_free_plan_tenant_expired_logs.py b/api/services/clear_free_plan_tenant_expired_logs.py index 0e0eab00ad..1c128524ad 100644 --- a/api/services/clear_free_plan_tenant_expired_logs.py +++ b/api/services/clear_free_plan_tenant_expired_logs.py @@ -6,11 +6,11 @@ from concurrent.futures import ThreadPoolExecutor import click from flask import Flask, current_app +from graphon.model_runtime.utils.encoders import jsonable_encoder from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker from configs import dify_config -from dify_graph.model_runtime.utils.encoders import jsonable_encoder from enums.cloud_plan import CloudPlan from extensions.ext_database import db from extensions.ext_storage import storage diff --git a/api/services/conversation_service.py b/api/services/conversation_service.py index 566c27c0f3..ba1e7bb826 100644 --- a/api/services/conversation_service.py +++ b/api/services/conversation_service.py @@ -3,6 +3,7 @@ import logging from collections.abc import Callable, Sequence from typing import Any, Union +from graphon.variables.types import SegmentType from sqlalchemy import asc, desc, func, or_, select from sqlalchemy.orm import Session @@ -10,7 +11,6 @@ from configs import dify_config from core.app.entities.app_invoke_entities import InvokeFrom from core.db.session_factory import session_factory from core.llm_generator.llm_generator import LLMGenerator -from dify_graph.variables.types import SegmentType from extensions.ext_database import db from factories import variable_factory from libs.datetime_utils import naive_utc_now diff --git a/api/services/conversation_variable_updater.py b/api/services/conversation_variable_updater.py index f00e3fe01e..95a8951951 100644 --- 
a/api/services/conversation_variable_updater.py +++ b/api/services/conversation_variable_updater.py @@ -1,7 +1,7 @@ +from graphon.variables.variables import VariableBase from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker -from dify_graph.variables.variables import VariableBase from models import ConversationVariable diff --git a/api/services/dataset_service.py b/api/services/dataset_service.py index 969ca68545..83363125c3 100644 --- a/api/services/dataset_service.py +++ b/api/services/dataset_service.py @@ -10,6 +10,9 @@ from collections.abc import Sequence from typing import Any, Literal, cast import sqlalchemy as sa +from graphon.file import helpers as file_helpers +from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType +from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from redis.exceptions import LockNotOwnedError from sqlalchemy import exists, func, select from sqlalchemy.orm import Session @@ -23,9 +26,6 @@ from core.model_manager import ModelManager from core.rag.index_processor.constant.built_in_field import BuiltInField from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.retrieval.retrieval_methods import RetrievalMethod -from dify_graph.file import helpers as file_helpers -from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelType -from dify_graph.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel from enums.cloud_plan import CloudPlan from events.dataset_event import dataset_was_deleted from events.document_event import document_was_deleted @@ -229,7 +229,7 @@ class DatasetService: raise DatasetNameDuplicateError(f"Dataset with name {name} already exists.") embedding_model = None if indexing_technique == IndexTechniqueType.HIGH_QUALITY: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id) if 
embedding_model_provider and embedding_model_name: # check if embedding model setting is valid DatasetService.check_embedding_model_setting(tenant_id, embedding_model_provider, embedding_model_name) @@ -354,7 +354,7 @@ class DatasetService: def check_dataset_model_setting(dataset): if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=dataset.tenant_id) model_manager.get_model_instance( tenant_id=dataset.tenant_id, provider=dataset.embedding_model_provider, @@ -371,7 +371,7 @@ class DatasetService: @staticmethod def check_embedding_model_setting(tenant_id: str, embedding_model_provider: str, embedding_model: str): try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id) model_manager.get_model_instance( tenant_id=tenant_id, provider=embedding_model_provider, @@ -388,7 +388,7 @@ class DatasetService: @staticmethod def check_is_multimodal_model(tenant_id: str, model_provider: str, model: str): try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id) model_instance = model_manager.get_model_instance( tenant_id=tenant_id, provider=model_provider, @@ -409,7 +409,7 @@ class DatasetService: @staticmethod def check_reranking_model_setting(tenant_id: str, reranking_model_provider: str, reranking_model: str): try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=tenant_id) model_manager.get_model_instance( tenant_id=tenant_id, provider=reranking_model_provider, @@ -746,7 +746,7 @@ class DatasetService: """ # assert isinstance(current_user, Account) and current_user.current_tenant_id is not None try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=current_user.current_tenant_id) assert isinstance(current_user, Account) assert current_user.current_tenant_id is not None embedding_model = 
model_manager.get_model_instance( @@ -864,7 +864,7 @@ class DatasetService: """ # assert isinstance(current_user, Account) and current_user.current_tenant_id is not None - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=current_user.current_tenant_id) try: assert isinstance(current_user, Account) assert current_user.current_tenant_id is not None @@ -958,7 +958,7 @@ class DatasetService: dataset.chunk_structure = knowledge_configuration.chunk_structure dataset.indexing_technique = IndexTechniqueType(knowledge_configuration.indexing_technique) if knowledge_configuration.indexing_technique == IndexTechniqueType.HIGH_QUALITY: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=current_user.current_tenant_id) embedding_model = model_manager.get_model_instance( tenant_id=current_user.current_tenant_id, # ignore type error provider=knowledge_configuration.embedding_model_provider or "", @@ -1000,7 +1000,7 @@ class DatasetService: action = "add" # get embedding model setting try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=current_user.current_tenant_id) embedding_model = model_manager.get_model_instance( tenant_id=current_user.current_tenant_id, provider=knowledge_configuration.embedding_model_provider, @@ -1053,7 +1053,7 @@ class DatasetService: or knowledge_configuration.embedding_model != dataset.embedding_model ): action = "update" - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=current_user.current_tenant_id) embedding_model = None try: embedding_model = model_manager.get_model_instance( @@ -1912,7 +1912,7 @@ class DocumentService: dataset.indexing_technique = IndexTechniqueType(knowledge_config.indexing_technique) if knowledge_config.indexing_technique == IndexTechniqueType.HIGH_QUALITY: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=current_user.current_tenant_id) if 
knowledge_config.embedding_model and knowledge_config.embedding_model_provider: dataset_embedding_model = knowledge_config.embedding_model dataset_embedding_model_provider = knowledge_config.embedding_model_provider @@ -2224,7 +2224,7 @@ class DocumentService: # dataset.indexing_technique = knowledge_config.indexing_technique # if knowledge_config.indexing_technique == "high_quality": - # model_manager = ModelManager() + # model_manager = ModelManager.for_tenant(tenant_id=current_user.current_tenant_id) # if knowledge_config.embedding_model and knowledge_config.embedding_model_provider: # dataset_embedding_model = knowledge_config.embedding_model # dataset_embedding_model_provider = knowledge_config.embedding_model_provider @@ -3129,7 +3129,7 @@ class SegmentService: segment_hash = helper.generate_text_hash(content) tokens = 0 if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=current_user.current_tenant_id) embedding_model = model_manager.get_model_instance( tenant_id=current_user.current_tenant_id, provider=dataset.embedding_model_provider, @@ -3212,7 +3212,7 @@ class SegmentService: with redis_client.lock(lock_name, timeout=600): embedding_model = None if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=current_user.current_tenant_id) embedding_model = model_manager.get_model_instance( tenant_id=current_user.current_tenant_id, provider=dataset.embedding_model_provider, @@ -3350,7 +3350,7 @@ class SegmentService: # get embedding model instance if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: # check embedding model setting - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=dataset.tenant_id) if dataset.embedding_model_provider: embedding_model_instance = model_manager.get_model_instance( @@ -3413,7 +3413,7 @@ class 
SegmentService: segment_hash = helper.generate_text_hash(content) tokens = 0 if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=current_user.current_tenant_id) embedding_model = model_manager.get_model_instance( tenant_id=current_user.current_tenant_id, provider=dataset.embedding_model_provider, @@ -3454,7 +3454,7 @@ class SegmentService: # get embedding model instance if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: # check embedding model setting - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=dataset.tenant_id) if dataset.embedding_model_provider: embedding_model_instance = model_manager.get_model_instance( diff --git a/api/services/datasource_provider_service.py b/api/services/datasource_provider_service.py index f3b2adb965..06f83a18f7 100644 --- a/api/services/datasource_provider_service.py +++ b/api/services/datasource_provider_service.py @@ -3,6 +3,7 @@ import time from collections.abc import Mapping from typing import Any +from graphon.model_runtime.entities.provider_entities import FormType from sqlalchemy.orm import Session from configs import dify_config @@ -14,7 +15,6 @@ from core.plugin.entities.plugin_daemon import CredentialType from core.plugin.impl.datasource import PluginDatasourceManager from core.plugin.impl.oauth import OAuthHandler from core.tools.utils.encryption import ProviderConfigCache, ProviderConfigEncrypter, create_provider_encrypter -from dify_graph.model_runtime.entities.provider_entities import FormType from extensions.ext_database import db from extensions.ext_redis import redis_client from models.oauth import DatasourceOauthParamConfig, DatasourceOauthTenantParamConfig, DatasourceProvider diff --git a/api/services/entities/model_provider_entities.py b/api/services/entities/model_provider_entities.py index 9dd595f516..a944ef6acd 100644 --- 
a/api/services/entities/model_provider_entities.py +++ b/api/services/entities/model_provider_entities.py @@ -1,6 +1,15 @@ from collections.abc import Sequence from enum import StrEnum +from graphon.model_runtime.entities.common_entities import I18nObject +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.entities.provider_entities import ( + ConfigurateMethod, + ModelCredentialSchema, + ProviderCredentialSchema, + ProviderHelpEntity, + SimpleProviderEntity, +) from pydantic import BaseModel, ConfigDict, model_validator from configs import dify_config @@ -15,15 +24,6 @@ from core.entities.provider_entities import ( QuotaConfiguration, UnaddedModelConfiguration, ) -from dify_graph.model_runtime.entities.common_entities import I18nObject -from dify_graph.model_runtime.entities.model_entities import ModelType -from dify_graph.model_runtime.entities.provider_entities import ( - ConfigurateMethod, - ModelCredentialSchema, - ProviderCredentialSchema, - ProviderHelpEntity, - SimpleProviderEntity, -) from models.provider import ProviderType diff --git a/api/services/external_knowledge_service.py b/api/services/external_knowledge_service.py index 4cf42b7f44..64852c222f 100644 --- a/api/services/external_knowledge_service.py +++ b/api/services/external_knowledge_service.py @@ -4,12 +4,12 @@ from typing import Any, Union, cast from urllib.parse import urlparse import httpx +from graphon.nodes.http_request.exc import InvalidHttpMethodError from sqlalchemy import select from constants import HIDDEN_VALUE from core.helper import ssrf_proxy from core.rag.entities.metadata_entities import MetadataCondition -from dify_graph.nodes.http_request.exc import InvalidHttpMethodError from extensions.ext_database import db from libs.datetime_utils import naive_utc_now from models.dataset import ( diff --git a/api/services/file_service.py b/api/services/file_service.py index a7060f3b92..50a326d813 100644 --- a/api/services/file_service.py +++ 
b/api/services/file_service.py @@ -8,6 +8,7 @@ from tempfile import NamedTemporaryFile from typing import Literal, Union from zipfile import ZIP_DEFLATED, ZipFile +from graphon.file import helpers as file_helpers from sqlalchemy import Engine, select from sqlalchemy.orm import Session, sessionmaker from werkzeug.exceptions import NotFound @@ -20,7 +21,6 @@ from constants import ( VIDEO_EXTENSIONS, ) from core.rag.extractor.extract_processor import ExtractProcessor -from dify_graph.file import helpers as file_helpers from extensions.ext_database import db from extensions.ext_storage import storage from extensions.storage.storage_type import StorageType diff --git a/api/services/hit_testing_service.py b/api/services/hit_testing_service.py index 9993d24c70..82e0b0f8b1 100644 --- a/api/services/hit_testing_service.py +++ b/api/services/hit_testing_service.py @@ -3,13 +3,14 @@ import logging import time from typing import Any +from graphon.model_runtime.entities import LLMMode + from core.app.app_config.entities import ModelConfig from core.rag.datasource.retrieval_service import RetrievalService from core.rag.index_processor.constant.query_type import QueryType from core.rag.models.document import Document from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.rag.retrieval.retrieval_methods import RetrievalMethod -from dify_graph.model_runtime.entities import LLMMode from extensions.ext_database import db from models import Account from models.dataset import Dataset, DatasetQuery diff --git a/api/services/human_input_delivery_test_service.py b/api/services/human_input_delivery_test_service.py index 229e6608da..77576fa4c0 100644 --- a/api/services/human_input_delivery_test_service.py +++ b/api/services/human_input_delivery_test_service.py @@ -4,18 +4,18 @@ from dataclasses import dataclass, field from enum import StrEnum from typing import Protocol +from graphon.runtime import VariablePool from sqlalchemy import Engine, select from sqlalchemy.orm 
import sessionmaker from configs import dify_config -from dify_graph.nodes.human_input.entities import ( +from core.workflow.human_input_compat import ( DeliveryChannelConfig, EmailDeliveryConfig, EmailDeliveryMethod, ExternalRecipient, MemberRecipient, ) -from dify_graph.runtime import VariablePool from extensions.ext_database import db from extensions.ext_mail import mail from libs.email_template_renderer import render_email_template @@ -177,21 +177,21 @@ class EmailDeliveryTestHandler: def _resolve_recipients(self, *, tenant_id: str, method: EmailDeliveryMethod) -> list[str]: recipients = method.config.recipients emails: list[str] = [] - member_user_ids: list[str] = [] + bound_reference_ids: list[str] = [] for recipient in recipients.items: if isinstance(recipient, MemberRecipient): - member_user_ids.append(recipient.user_id) + bound_reference_ids.append(recipient.reference_id) elif isinstance(recipient, ExternalRecipient): if recipient.email: emails.append(recipient.email) - if recipients.whole_workspace: - member_user_ids = [] + if recipients.include_bound_group: + bound_reference_ids = [] member_emails = self._query_workspace_member_emails(tenant_id=tenant_id, user_ids=None) emails.extend(member_emails.values()) - elif member_user_ids: - member_emails = self._query_workspace_member_emails(tenant_id=tenant_id, user_ids=member_user_ids) - for user_id in member_user_ids: + elif bound_reference_ids: + member_emails = self._query_workspace_member_emails(tenant_id=tenant_id, user_ids=bound_reference_ids) + for user_id in bound_reference_ids: email = member_emails.get(user_id) if email: emails.append(email) @@ -220,7 +220,7 @@ class EmailDeliveryTestHandler: stmt = stmt.where(Account.id.in_(unique_ids)) with self._session_factory() as session: - rows = session.execute(stmt).all() + rows = session.execute(stmt).tuples().all() return dict(rows) @staticmethod diff --git a/api/services/human_input_service.py b/api/services/human_input_service.py index 
2e74c50963..02a6620fc7 100644 --- a/api/services/human_input_service.py +++ b/api/services/human_input_service.py @@ -3,6 +3,12 @@ from collections.abc import Mapping from datetime import datetime, timedelta from typing import Any +from graphon.nodes.human_input.entities import ( + FormDefinition, + HumanInputSubmissionValidationError, + validate_human_input_submission, +) +from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from sqlalchemy import Engine, select from sqlalchemy.orm import Session, sessionmaker @@ -11,12 +17,6 @@ from core.repositories.human_input_repository import ( HumanInputFormRecord, HumanInputFormSubmissionRepository, ) -from dify_graph.nodes.human_input.entities import ( - FormDefinition, - HumanInputSubmissionValidationError, - validate_human_input_submission, -) -from dify_graph.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from libs.datetime_utils import ensure_naive_utc, naive_utc_now from libs.exception import BaseHTTPException from models.human_input import RecipientType diff --git a/api/services/knowledge_service.py b/api/services/knowledge_service.py index 02fe1d19bc..3d6fdb08a3 100644 --- a/api/services/knowledge_service.py +++ b/api/services/knowledge_service.py @@ -1,12 +1,20 @@ import boto3 +from pydantic import BaseModel, Field from configs import dify_config +class BedrockRetrievalSetting(BaseModel): + """Retrieval settings for Amazon Bedrock knowledge base queries.""" + + top_k: int | None = Field(default=None, description="Maximum number of results to retrieve") + score_threshold: float = Field(default=0.0, description="Minimum relevance score threshold") + + class ExternalDatasetTestService: # this service is only for internal testing @staticmethod - def knowledge_retrieval(retrieval_setting: dict, query: str, knowledge_id: str): + def knowledge_retrieval(retrieval_setting: BedrockRetrievalSetting, query: str, knowledge_id: str): # get bedrock client client = 
boto3.client( "bedrock-agent-runtime", @@ -20,7 +28,7 @@ class ExternalDatasetTestService: knowledgeBaseId=knowledge_id, retrievalConfiguration={ "vectorSearchConfiguration": { - "numberOfResults": retrieval_setting.get("top_k"), + "numberOfResults": retrieval_setting.top_k, "overrideSearchType": "HYBRID", } }, @@ -33,7 +41,7 @@ class ExternalDatasetTestService: retrieval_results = response.get("retrievalResults") for retrieval_result in retrieval_results: # filter out results with score less than threshold - if retrieval_result.get("score") < retrieval_setting.get("score_threshold", 0.0): + if retrieval_result.get("score") < retrieval_setting.score_threshold: continue result = { "metadata": retrieval_result.get("metadata"), diff --git a/api/services/message_service.py b/api/services/message_service.py index fc87802f51..e5389ef659 100644 --- a/api/services/message_service.py +++ b/api/services/message_service.py @@ -2,6 +2,7 @@ import json from collections.abc import Sequence from typing import Union +from graphon.model_runtime.entities.model_entities import ModelType from sqlalchemy.orm import sessionmaker from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager @@ -12,7 +13,6 @@ from core.model_manager import ModelManager from core.ops.entities.trace_entity import TraceTaskName from core.ops.ops_trace_manager import TraceQueueManager, TraceTask from core.ops.utils import measure_time -from dify_graph.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db from libs.infinite_scroll_pagination import InfiniteScrollPagination from models import Account @@ -255,7 +255,7 @@ class MessageService: app_model=app_model, conversation_id=message.conversation_id, user=user ) - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=app_model.tenant_id) if app_model.mode == AppMode.ADVANCED_CHAT: workflow_service = WorkflowService() diff --git 
a/api/services/model_load_balancing_service.py b/api/services/model_load_balancing_service.py index bf3b6db3ed..91cca5cb6d 100644 --- a/api/services/model_load_balancing_service.py +++ b/api/services/model_load_balancing_service.py @@ -3,6 +3,12 @@ import logging from json import JSONDecodeError from typing import Union +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.entities.provider_entities import ( + ModelCredentialSchema, + ProviderCredentialSchema, +) +from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from sqlalchemy import or_, select from constants import HIDDEN_VALUE @@ -10,13 +16,8 @@ from core.entities.provider_configuration import ProviderConfiguration from core.helper import encrypter from core.helper.model_provider_cache import ProviderCredentialsCache, ProviderCredentialsCacheType from core.model_manager import LBModelManager +from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly, create_plugin_provider_manager from core.provider_manager import ProviderManager -from dify_graph.model_runtime.entities.model_entities import ModelType -from dify_graph.model_runtime.entities.provider_entities import ( - ModelCredentialSchema, - ProviderCredentialSchema, -) -from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from extensions.ext_database import db from libs.datetime_utils import naive_utc_now from models.enums import CredentialSourceType @@ -26,8 +27,9 @@ logger = logging.getLogger(__name__) class ModelLoadBalancingService: - def __init__(self): - self.provider_manager = ProviderManager() + @staticmethod + def _get_provider_manager(tenant_id: str) -> ProviderManager: + return create_plugin_provider_manager(tenant_id=tenant_id) def enable_model_load_balancing(self, tenant_id: str, provider: str, model: str, model_type: str): """ @@ -40,7 +42,7 @@ class ModelLoadBalancingService: :return: """ # 
Get all provider configurations of the current workspace - provider_configurations = self.provider_manager.get_configurations(tenant_id) + provider_configurations = self._get_provider_manager(tenant_id).get_configurations(tenant_id) # Get provider configuration provider_configuration = provider_configurations.get(provider) @@ -61,7 +63,7 @@ class ModelLoadBalancingService: :return: """ # Get all provider configurations of the current workspace - provider_configurations = self.provider_manager.get_configurations(tenant_id) + provider_configurations = self._get_provider_manager(tenant_id).get_configurations(tenant_id) # Get provider configuration provider_configuration = provider_configurations.get(provider) @@ -83,7 +85,7 @@ class ModelLoadBalancingService: :return: """ # Get all provider configurations of the current workspace - provider_configurations = self.provider_manager.get_configurations(tenant_id) + provider_configurations = self._get_provider_manager(tenant_id).get_configurations(tenant_id) # Get provider configuration provider_configuration = provider_configurations.get(provider) @@ -222,8 +224,8 @@ class ModelLoadBalancingService: :param config_id: load balancing config id :return: """ - # Get all provider configurations of the current workspace - provider_configurations = self.provider_manager.get_configurations(tenant_id) + provider_manager = create_plugin_provider_manager(tenant_id=tenant_id) + provider_configurations = provider_manager.get_configurations(tenant_id) # Get provider configuration provider_configuration = provider_configurations.get(provider) @@ -310,7 +312,7 @@ class ModelLoadBalancingService: :return: """ # Get all provider configurations of the current workspace - provider_configurations = self.provider_manager.get_configurations(tenant_id) + provider_configurations = self._get_provider_manager(tenant_id).get_configurations(tenant_id) # Get provider configuration provider_configuration = provider_configurations.get(provider) @@ -495,8 
+497,8 @@ class ModelLoadBalancingService: :param config_id: load balancing config id :return: """ - # Get all provider configurations of the current workspace - provider_configurations = self.provider_manager.get_configurations(tenant_id) + assembly = create_plugin_model_assembly(tenant_id=tenant_id) + provider_configurations = assembly.provider_manager.get_configurations(tenant_id) # Get provider configuration provider_configuration = provider_configurations.get(provider) @@ -532,6 +534,7 @@ class ModelLoadBalancingService: model=model, credentials=credentials, load_balancing_model_config=load_balancing_model_config, + model_provider_factory=assembly.model_provider_factory, ) def _custom_credentials_validate( @@ -542,6 +545,7 @@ class ModelLoadBalancingService: model: str, credentials: dict, load_balancing_model_config: LoadBalancingModelConfig | None = None, + model_provider_factory: ModelProviderFactory | None = None, validate: bool = True, ): """ @@ -552,6 +556,7 @@ class ModelLoadBalancingService: :param model: model name :param credentials: credentials :param load_balancing_model_config: load balancing model config + :param model_provider_factory: model provider factory sharing the active runtime :param validate: validate credentials :return: """ @@ -581,7 +586,8 @@ class ModelLoadBalancingService: credentials[key] = encrypter.decrypt_token(tenant_id, original_credentials[key]) if validate: - model_provider_factory = ModelProviderFactory(tenant_id) + if model_provider_factory is None: + model_provider_factory = provider_configuration.get_model_provider_factory() if isinstance(credential_schemas, ModelCredentialSchema): credentials = model_provider_factory.model_credentials_validate( provider=provider_configuration.provider.provider, diff --git a/api/services/model_provider_service.py b/api/services/model_provider_service.py index 0ddd6b9b1a..3f37c9b176 100644 --- a/api/services/model_provider_service.py +++ b/api/services/model_provider_service.py @@ -1,9 
+1,10 @@ import logging +from graphon.model_runtime.entities.model_entities import ModelType, ParameterRule + from core.entities.model_entities import ModelWithProviderEntity, ProviderModelWithStatusEntity +from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory, create_plugin_provider_manager from core.provider_manager import ProviderManager -from dify_graph.model_runtime.entities.model_entities import ModelType, ParameterRule -from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory from models.provider import ProviderType from services.entities.model_provider_entities import ( CustomConfigurationResponse, @@ -25,8 +26,9 @@ class ModelProviderService: Model Provider Service """ - def __init__(self): - self.provider_manager = ProviderManager() + @staticmethod + def _get_provider_manager(tenant_id: str) -> ProviderManager: + return create_plugin_provider_manager(tenant_id=tenant_id) def _get_provider_configuration(self, tenant_id: str, provider: str): """ @@ -43,7 +45,7 @@ class ModelProviderService: ProviderNotFoundError: If provider doesn't exist """ # Get all provider configurations of the current workspace - provider_configurations = self.provider_manager.get_configurations(tenant_id) + provider_configurations = self._get_provider_manager(tenant_id).get_configurations(tenant_id) provider_configuration = provider_configurations.get(provider) if not provider_configuration: @@ -60,7 +62,7 @@ class ModelProviderService: :return: """ # Get all provider configurations of the current workspace - provider_configurations = self.provider_manager.get_configurations(tenant_id) + provider_configurations = self._get_provider_manager(tenant_id).get_configurations(tenant_id) provider_responses = [] for provider_configuration in provider_configurations.values(): @@ -138,7 +140,7 @@ class ModelProviderService: :return: """ # Get all provider configurations of the current workspace - provider_configurations = 
self.provider_manager.get_configurations(tenant_id) + provider_configurations = self._get_provider_manager(tenant_id).get_configurations(tenant_id) # Get provider available models return [ @@ -146,6 +148,26 @@ class ModelProviderService: for model in provider_configurations.get_models(provider=provider) ] + def get_provider_available_credentials(self, tenant_id: str, provider: str): + return self._get_provider_manager(tenant_id).get_provider_available_credentials( + tenant_id=tenant_id, + provider_name=provider, + ) + + def get_provider_model_available_credentials( + self, + tenant_id: str, + provider: str, + model_type: str, + model: str, + ): + return self._get_provider_manager(tenant_id).get_provider_model_available_credentials( + tenant_id=tenant_id, + provider_name=provider, + model_type=model_type, + model_name=model, + ) + def get_provider_credential(self, tenant_id: str, provider: str, credential_id: str | None = None) -> dict | None: """ get provider credentials. @@ -391,7 +413,7 @@ class ModelProviderService: :return: """ # Get all provider configurations of the current workspace - provider_configurations = self.provider_manager.get_configurations(tenant_id) + provider_configurations = self._get_provider_manager(tenant_id).get_configurations(tenant_id) # Get provider available models models = provider_configurations.get_models(model_type=ModelType.value_of(model_type), only_active=True) @@ -476,7 +498,9 @@ class ModelProviderService: model_type_enum = ModelType.value_of(model_type) try: - result = self.provider_manager.get_default_model(tenant_id=tenant_id, model_type=model_type_enum) + result = self._get_provider_manager(tenant_id).get_default_model( + tenant_id=tenant_id, model_type=model_type_enum + ) return ( DefaultModelResponse( model=result.model, @@ -507,7 +531,7 @@ class ModelProviderService: :return: """ model_type_enum = ModelType.value_of(model_type) - self.provider_manager.update_default_model_record( + 
self._get_provider_manager(tenant_id).update_default_model_record( tenant_id=tenant_id, model_type=model_type_enum, provider=provider, model=model ) @@ -523,7 +547,7 @@ class ModelProviderService: :param lang: language (zh_Hans or en_US) :return: """ - model_provider_factory = ModelProviderFactory(tenant_id) + model_provider_factory = create_plugin_model_provider_factory(tenant_id=tenant_id) byte_data, mime_type = model_provider_factory.get_provider_icon(provider, icon_type, lang) return byte_data, mime_type diff --git a/api/services/rag_pipeline/rag_pipeline.py b/api/services/rag_pipeline/rag_pipeline.py index 296b9f0890..bcf5973d7b 100644 --- a/api/services/rag_pipeline/rag_pipeline.py +++ b/api/services/rag_pipeline/rag_pipeline.py @@ -9,6 +9,15 @@ from typing import Any, Union, cast from uuid import uuid4 from flask_login import current_user +from graphon.entities import WorkflowNodeExecution +from graphon.enums import BuiltinNodeTypes, ErrorStrategy, NodeType, WorkflowNodeExecutionStatus +from graphon.errors import WorkflowNodeRunFailedError +from graphon.graph_events import GraphNodeEventBase, NodeRunFailedEvent, NodeRunSucceededEvent +from graphon.node_events import NodeRunResult +from graphon.nodes.base.node import Node +from graphon.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, build_http_request_config +from graphon.runtime import VariablePool +from graphon.variables.variables import Variable, VariableBase from sqlalchemy import func, select from sqlalchemy.orm import Session, sessionmaker @@ -34,25 +43,19 @@ from core.rag.entities.event import ( DatasourceErrorEvent, DatasourceProcessingEvent, ) -from core.repositories.factory import DifyCoreRepositoryFactory +from core.repositories.factory import DifyCoreRepositoryFactory, OrderConfig from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository from core.workflow.node_factory import LATEST_VERSION, get_node_type_classes_mapping -from 
core.workflow.workflow_entry import WorkflowEntry -from dify_graph.entities.workflow_node_execution import ( - WorkflowNodeExecution, - WorkflowNodeExecutionStatus, +from core.workflow.system_variables import ( + SystemVariableKey, + build_bootstrap_variables, + build_system_variables, + default_system_variables, + get_system_segment, ) -from dify_graph.enums import BuiltinNodeTypes, ErrorStrategy, NodeType, SystemVariableKey -from dify_graph.errors import WorkflowNodeRunFailedError -from dify_graph.graph_events import NodeRunFailedEvent, NodeRunSucceededEvent -from dify_graph.graph_events.base import GraphNodeEventBase -from dify_graph.node_events.base import NodeRunResult -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, build_http_request_config -from dify_graph.repositories.workflow_node_execution_repository import OrderConfig -from dify_graph.runtime import VariablePool -from dify_graph.system_variable import SystemVariable -from dify_graph.variables.variables import VariableBase +from core.workflow.variable_pool_initializer import add_variables_to_pool +from core.workflow.workflow_entry import WorkflowEntry +from enterprise.telemetry.draft_trace import enqueue_draft_node_execution_trace from extensions.ext_database import db from libs.infinite_scroll_pagination import InfiniteScrollPagination from models import Account @@ -88,6 +91,12 @@ from services.workflow_restore import apply_published_workflow_snapshot_to_draft logger = logging.getLogger(__name__) +def _build_seeded_variable_pool(variables: Sequence[Variable]) -> VariablePool: + variable_pool = VariablePool() + add_variables_to_pool(variable_pool, variables) + return variable_pool + + class RagPipelineService: def __init__(self, session_maker: sessionmaker | None = None): """Initialize RagPipelineService with repository dependencies.""" @@ -521,13 +530,7 @@ class RagPipelineService: node_id=node_id, user_inputs=user_inputs, 
user_id=account.id, - variable_pool=VariablePool( - system_variables=SystemVariable.default(), - user_inputs=user_inputs, - environment_variables=[], - conversation_variables=[], - rag_pipeline_variables=[], - ), + variable_pool=_build_seeded_variable_pool(default_system_variables()), variable_loader=DraftVarLoader( engine=db.engine, app_id=pipeline.id, @@ -571,6 +574,13 @@ class RagPipelineService: outputs=workflow_node_execution.outputs, ) session.commit() + if workflow_node_execution_db_model is not None: + enqueue_draft_node_execution_trace( + execution=workflow_node_execution_db_model, + outputs=workflow_node_execution.outputs, + workflow_execution_id=None, + user_id=account.id, + ) return workflow_node_execution_db_model def run_datasource_workflow_node( @@ -959,10 +969,10 @@ class RagPipelineService: workflow_node_execution.error = error # update document status variable_pool = node_instance.graph_runtime_state.variable_pool - invoke_from = variable_pool.get(["sys", SystemVariableKey.INVOKE_FROM]) + invoke_from = get_system_segment(variable_pool, SystemVariableKey.INVOKE_FROM) if invoke_from: if invoke_from.value == InvokeFrom.PUBLISHED_PIPELINE: - document_id = variable_pool.get(["sys", SystemVariableKey.DOCUMENT_ID]) + document_id = get_system_segment(variable_pool, SystemVariableKey.DOCUMENT_ID) if document_id: document = db.session.query(Document).where(Document.id == document_id.value).first() if document: @@ -1276,7 +1286,7 @@ class RagPipelineService: else: enclosing_node_id = None - system_inputs = SystemVariable( + system_inputs = build_system_variables( datasource_type=args.get("datasource_type", "online_document"), datasource_info=args.get("datasource_info", {}), ) @@ -1287,12 +1297,11 @@ class RagPipelineService: node_id=node_id, user_inputs={}, user_id=current_user.id, - variable_pool=VariablePool( - system_variables=system_inputs, - user_inputs={}, - environment_variables=[], - conversation_variables=[], - rag_pipeline_variables=[], + 
variable_pool=_build_seeded_variable_pool( + build_bootstrap_variables( + system_variables=system_inputs, + rag_pipeline_variables=(), + ) ), variable_loader=DraftVarLoader( engine=db.engine, @@ -1334,6 +1343,12 @@ class RagPipelineService: outputs=workflow_node_execution.outputs, ) session.commit() + enqueue_draft_node_execution_trace( + execution=workflow_node_execution_db_model, + outputs=workflow_node_execution.outputs, + workflow_execution_id=None, + user_id=current_user.id, + ) return workflow_node_execution_db_model def get_recommended_plugins(self, type: str) -> dict: diff --git a/api/services/rag_pipeline/rag_pipeline_dsl_service.py b/api/services/rag_pipeline/rag_pipeline_dsl_service.py index fd66d55c1a..04156713f4 100644 --- a/api/services/rag_pipeline/rag_pipeline_dsl_service.py +++ b/api/services/rag_pipeline/rag_pipeline_dsl_service.py @@ -14,6 +14,12 @@ import yaml # type: ignore from Crypto.Cipher import AES from Crypto.Util.Padding import pad, unpad from flask_login import current_user +from graphon.enums import BuiltinNodeTypes +from graphon.model_runtime.utils.encoders import jsonable_encoder +from graphon.nodes.llm.entities import LLMNodeData +from graphon.nodes.parameter_extractor.entities import ParameterExtractorNodeData +from graphon.nodes.question_classifier.entities import QuestionClassifierNodeData +from graphon.nodes.tool.entities import ToolNodeData from packaging import version from pydantic import BaseModel, Field from sqlalchemy import select @@ -26,12 +32,6 @@ from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.workflow.nodes.datasource.entities import DatasourceNodeData from core.workflow.nodes.knowledge_index import KNOWLEDGE_INDEX_NODE_TYPE from core.workflow.nodes.knowledge_retrieval.entities import KnowledgeRetrievalNodeData -from dify_graph.enums import BuiltinNodeTypes -from dify_graph.model_runtime.utils.encoders import jsonable_encoder -from dify_graph.nodes.llm.entities import LLMNodeData 
-from dify_graph.nodes.parameter_extractor.entities import ParameterExtractorNodeData -from dify_graph.nodes.question_classifier.entities import QuestionClassifierNodeData -from dify_graph.nodes.tool.entities import ToolNodeData from extensions.ext_redis import redis_client from factories import variable_factory from models import Account diff --git a/api/services/retention/workflow_run/archive_paid_plan_workflow_run.py b/api/services/retention/workflow_run/archive_paid_plan_workflow_run.py index 00a2144800..2c1f99a3bc 100644 --- a/api/services/retention/workflow_run/archive_paid_plan_workflow_run.py +++ b/api/services/retention/workflow_run/archive_paid_plan_workflow_run.py @@ -27,11 +27,11 @@ from dataclasses import dataclass, field from typing import Any import click +from graphon.enums import WorkflowType from sqlalchemy import inspect from sqlalchemy.orm import Session, sessionmaker from configs import dify_config -from dify_graph.enums import WorkflowType from enums.cloud_plan import CloudPlan from extensions.ext_database import db from libs.archive_storage import ( diff --git a/api/services/summary_index_service.py b/api/services/summary_index_service.py index ed7a33feae..12053377e2 100644 --- a/api/services/summary_index_service.py +++ b/api/services/summary_index_service.py @@ -6,6 +6,8 @@ import uuid from datetime import UTC, datetime from typing import Any +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.model_runtime.entities.model_entities import ModelType from sqlalchemy.orm import Session from core.db.session_factory import session_factory @@ -15,8 +17,6 @@ from core.rag.index_processor.constant.doc_type import DocType from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.index_processor.index_processor_base import SummaryIndexSettingDict from core.rag.models.document import Document -from dify_graph.model_runtime.entities.llm_entities import LLMUsage -from 
dify_graph.model_runtime.entities.model_entities import ModelType from libs import helper from models.dataset import Dataset, DocumentSegment, DocumentSegmentSummary from models.dataset import Document as DatasetDocument @@ -192,7 +192,7 @@ class SummaryIndexService: # Calculate embedding tokens for summary (for logging and statistics) embedding_tokens = 0 try: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=dataset.tenant_id) embedding_model = model_manager.get_model_instance( tenant_id=dataset.tenant_id, provider=dataset.embedding_model_provider, @@ -201,7 +201,8 @@ class SummaryIndexService: ) if embedding_model: tokens_list = embedding_model.get_text_embedding_num_tokens([summary_content]) - embedding_tokens = tokens_list[0] if tokens_list else 0 + raw_embedding_tokens = tokens_list[0] if tokens_list else 0 + embedding_tokens = raw_embedding_tokens if isinstance(raw_embedding_tokens, int) else 0 except Exception as e: logger.warning("Failed to calculate embedding tokens for summary: %s", str(e)) diff --git a/api/services/tools/api_tools_manage_service.py b/api/services/tools/api_tools_manage_service.py index 408b1c22d1..2a56bc0c71 100644 --- a/api/services/tools/api_tools_manage_service.py +++ b/api/services/tools/api_tools_manage_service.py @@ -2,6 +2,7 @@ import json import logging from typing import Any, cast +from graphon.model_runtime.utils.encoders import jsonable_encoder from httpx import get from sqlalchemy import select from typing_extensions import TypedDict @@ -20,7 +21,6 @@ from core.tools.tool_label_manager import ToolLabelManager from core.tools.tool_manager import ToolManager from core.tools.utils.encryption import create_tool_provider_encrypter from core.tools.utils.parser import ApiBasedToolSchemaParser -from dify_graph.model_runtime.utils.encoders import jsonable_encoder from extensions.ext_database import db from models.tools import ApiToolProvider from services.tools.tools_transform_service import 
ToolTransformService diff --git a/api/services/tools/builtin_tools_manage_service.py b/api/services/tools/builtin_tools_manage_service.py index 6797a67dde..8e3c36e099 100644 --- a/api/services/tools/builtin_tools_manage_service.py +++ b/api/services/tools/builtin_tools_manage_service.py @@ -275,7 +275,7 @@ class BuiltinToolManageService: user_id=user_id, provider=provider, encrypted_credentials=json.dumps(encrypter.encrypt(credentials)), - credential_type=api_type.value, + credential_type=api_type, name=name, expires_at=expires_at if expires_at is not None else -1, ) @@ -314,7 +314,7 @@ class BuiltinToolManageService: .filter_by( tenant_id=tenant_id, provider=provider, - credential_type=credential_type.value, + credential_type=credential_type, ) .order_by(BuiltinToolProvider.created_at.desc()) .all() diff --git a/api/services/tools/tools_transform_service.py b/api/services/tools/tools_transform_service.py index b6e5367023..b276146066 100644 --- a/api/services/tools/tools_transform_service.py +++ b/api/services/tools/tools_transform_service.py @@ -423,7 +423,7 @@ class ToolTransformService: id=provider.id, name=provider.name, provider=provider.provider, - credential_type=CredentialType.of(provider.credential_type), + credential_type=provider.credential_type, is_default=provider.is_default, credentials=credentials, ) diff --git a/api/services/tools/workflow_tools_manage_service.py b/api/services/tools/workflow_tools_manage_service.py index 101b2fe5a2..fb6b5bea24 100644 --- a/api/services/tools/workflow_tools_manage_service.py +++ b/api/services/tools/workflow_tools_manage_service.py @@ -2,6 +2,7 @@ import json import logging from datetime import datetime +from graphon.model_runtime.utils.encoders import jsonable_encoder from sqlalchemy import or_, select from sqlalchemy.orm import Session @@ -12,7 +13,6 @@ from core.tools.tool_label_manager import ToolLabelManager from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils from 
core.tools.workflow_as_tool.provider import WorkflowToolProviderController from core.tools.workflow_as_tool.tool import WorkflowTool -from dify_graph.model_runtime.utils.encoders import jsonable_encoder from extensions.ext_database import db from models.model import App from models.tools import WorkflowToolProvider diff --git a/api/services/trigger/schedule_service.py b/api/services/trigger/schedule_service.py index 7e9d010d2f..25e80770b8 100644 --- a/api/services/trigger/schedule_service.py +++ b/api/services/trigger/schedule_service.py @@ -2,6 +2,7 @@ import json import logging from datetime import datetime +from graphon.entities.graph_config import NodeConfigDict from sqlalchemy import select from sqlalchemy.orm import Session @@ -13,7 +14,6 @@ from core.workflow.nodes.trigger_schedule.entities import ( VisualConfig, ) from core.workflow.nodes.trigger_schedule.exc import ScheduleConfigError, ScheduleNotFoundError -from dify_graph.entities.graph_config import NodeConfigDict from libs.schedule_utils import calculate_next_run_at, convert_12h_to_24h from models.account import Account, TenantAccountJoin from models.trigger import WorkflowSchedulePlan diff --git a/api/services/trigger/trigger_provider_service.py b/api/services/trigger/trigger_provider_service.py index 688993c798..008d8bdb8a 100644 --- a/api/services/trigger/trigger_provider_service.py +++ b/api/services/trigger/trigger_provider_service.py @@ -198,7 +198,7 @@ class TriggerProviderService: credentials=dict(credential_encrypter.encrypt(dict(credentials))) if credential_encrypter else {}, - credential_type=credential_type.value, + credential_type=credential_type, credential_expires_at=credential_expires_at, expires_at=expires_at, ) diff --git a/api/services/trigger/trigger_service.py b/api/services/trigger/trigger_service.py index 24bbeda329..d72c041609 100644 --- a/api/services/trigger/trigger_service.py +++ b/api/services/trigger/trigger_service.py @@ -5,6 +5,7 @@ from collections.abc import Mapping 
from typing import Any from flask import Request, Response +from graphon.entities.graph_config import NodeConfigDict from pydantic import BaseModel from sqlalchemy import select from sqlalchemy.orm import Session @@ -18,7 +19,6 @@ from core.trigger.provider import PluginTriggerProviderController from core.trigger.trigger_manager import TriggerManager from core.trigger.utils.encryption import create_trigger_provider_encrypter_for_subscription from core.workflow.nodes.trigger_plugin.entities import TriggerEventNodeData -from dify_graph.entities.graph_config import NodeConfigDict from extensions.ext_database import db from extensions.ext_redis import redis_client from models.model import App diff --git a/api/services/trigger/webhook_service.py b/api/services/trigger/webhook_service.py index 3c1a4cc747..c03275497d 100644 --- a/api/services/trigger/webhook_service.py +++ b/api/services/trigger/webhook_service.py @@ -7,6 +7,9 @@ from typing import Any import orjson from flask import request +from graphon.entities.graph_config import NodeConfigDict +from graphon.file import FileTransferMethod +from graphon.variables.types import ArrayValidation, SegmentType from pydantic import BaseModel from sqlalchemy import select from sqlalchemy.orm import Session @@ -15,6 +18,7 @@ from werkzeug.exceptions import RequestEntityTooLarge from configs import dify_config from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.file_access import DatabaseFileAccessController from core.tools.tool_file_manager import ToolFileManager from core.trigger.constants import TRIGGER_WEBHOOK_NODE_TYPE from core.workflow.nodes.trigger_webhook.entities import ( @@ -23,9 +27,6 @@ from core.workflow.nodes.trigger_webhook.entities import ( WebhookData, WebhookParameter, ) -from dify_graph.entities.graph_config import NodeConfigDict -from dify_graph.file.models import FileTransferMethod -from dify_graph.variables.types import ArrayValidation, SegmentType from enums.quota_type import 
QuotaType from extensions.ext_database import db from extensions.ext_redis import redis_client @@ -46,6 +47,7 @@ except ImportError: magic = None # type: ignore[assignment] logger = logging.getLogger(__name__) +_file_access_controller = DatabaseFileAccessController() class WebhookService: @@ -422,6 +424,7 @@ class WebhookService: return file_factory.build_from_mapping( mapping=mapping, tenant_id=webhook_trigger.tenant_id, + access_controller=_file_access_controller, ) @classmethod diff --git a/api/services/variable_truncator.py b/api/services/variable_truncator.py index 60dc1dedb8..62916cc2c9 100644 --- a/api/services/variable_truncator.py +++ b/api/services/variable_truncator.py @@ -5,10 +5,9 @@ from abc import ABC, abstractmethod from collections.abc import Mapping from typing import Any, Generic, TypeAlias, TypeVar, overload -from configs import dify_config -from dify_graph.file.models import File -from dify_graph.nodes.variable_assigner.common.helpers import UpdatedVariable -from dify_graph.variables.segments import ( +from graphon.file import File +from graphon.nodes.variable_assigner.common.helpers import UpdatedVariable +from graphon.variables.segments import ( ArrayFileSegment, ArraySegment, BooleanSegment, @@ -20,7 +19,9 @@ from dify_graph.variables.segments import ( Segment, StringSegment, ) -from dify_graph.variables.utils import dumps_with_segments +from graphon.variables.utils import dumps_with_segments + +from configs import dify_config _MAX_DEPTH = 100 diff --git a/api/services/vector_service.py b/api/services/vector_service.py index bb94a03ba3..3f78b823a6 100644 --- a/api/services/vector_service.py +++ b/api/services/vector_service.py @@ -1,5 +1,7 @@ import logging +from graphon.model_runtime.entities.model_entities import ModelType + from core.model_manager import ModelInstance, ModelManager from core.rag.datasource.keyword.keyword_factory import Keyword from core.rag.datasource.vdb.vector_factory import Vector @@ -8,7 +10,6 @@ from 
core.rag.index_processor.constant.index_type import IndexStructureType, Ind from core.rag.index_processor.index_processor_base import BaseIndexProcessor from core.rag.index_processor.index_processor_factory import IndexProcessorFactory from core.rag.models.document import AttachmentDocument, Document -from dify_graph.model_runtime.entities.model_entities import ModelType from extensions.ext_database import db from models import UploadFile from models.dataset import ChildChunk, Dataset, DatasetProcessRule, DocumentSegment, SegmentAttachmentBinding @@ -47,7 +48,7 @@ class VectorService: # get embedding model instance if dataset.indexing_technique == IndexTechniqueType.HIGH_QUALITY: # check embedding model setting - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=dataset.tenant_id) if dataset.embedding_model_provider: embedding_model_instance = model_manager.get_model_instance( diff --git a/api/services/workflow/workflow_converter.py b/api/services/workflow/workflow_converter.py index f0596e44c8..31367f72fa 100644 --- a/api/services/workflow/workflow_converter.py +++ b/api/services/workflow/workflow_converter.py @@ -1,6 +1,11 @@ import json from typing import Any +from graphon.file import FileUploadConfig +from graphon.model_runtime.entities.llm_entities import LLMMode +from graphon.model_runtime.utils.encoders import jsonable_encoder +from graphon.nodes import BuiltinNodeTypes +from graphon.variables.input_entities import VariableEntity from typing_extensions import TypedDict from core.app.app_config.entities import ( @@ -17,11 +22,6 @@ from core.app.apps.completion.app_config_manager import CompletionAppConfigManag from core.helper import encrypter from core.prompt.simple_prompt_transform import SimplePromptTransform from core.prompt.utils.prompt_template_parser import PromptTemplateParser -from dify_graph.file.models import FileUploadConfig -from dify_graph.model_runtime.entities.llm_entities import LLMMode -from 
dify_graph.model_runtime.utils.encoders import jsonable_encoder -from dify_graph.nodes import BuiltinNodeTypes -from dify_graph.variables.input_entities import VariableEntity from events.app_event import app_was_created from extensions.ext_database import db from models import Account diff --git a/api/services/workflow_app_service.py b/api/services/workflow_app_service.py index 9489618762..bf178e8a44 100644 --- a/api/services/workflow_app_service.py +++ b/api/services/workflow_app_service.py @@ -3,11 +3,11 @@ import uuid from datetime import datetime from typing import Any +from graphon.enums import WorkflowExecutionStatus from sqlalchemy import and_, func, or_, select from sqlalchemy.orm import Session from typing_extensions import TypedDict -from dify_graph.enums import WorkflowExecutionStatus from models import Account, App, EndUser, TenantAccountJoin, WorkflowAppLog, WorkflowArchiveLog, WorkflowRun from models.enums import AppTriggerType, CreatorUserRole from models.trigger import WorkflowTriggerLog diff --git a/api/services/workflow_draft_variable_service.py b/api/services/workflow_draft_variable_service.py index f124e137c3..98e338a2d4 100644 --- a/api/services/workflow_draft_variable_service.py +++ b/api/services/workflow_draft_variable_service.py @@ -6,6 +6,19 @@ from concurrent.futures import ThreadPoolExecutor from enum import StrEnum from typing import Any, ClassVar +from graphon.enums import NodeType +from graphon.file import File +from graphon.nodes import BuiltinNodeTypes +from graphon.nodes.variable_assigner.common.helpers import get_updated_variables +from graphon.variable_loader import VariableLoader +from graphon.variables import Segment, StringSegment, VariableBase +from graphon.variables.consts import SELECTORS_LENGTH +from graphon.variables.segments import ( + ArrayFileSegment, + FileSegment, +) +from graphon.variables.types import SegmentType +from graphon.variables.utils import dumps_with_segments from sqlalchemy import Engine, orm, select 
from sqlalchemy.dialects.mysql import insert as mysql_insert from sqlalchemy.dialects.postgresql import insert as pg_insert @@ -14,21 +27,15 @@ from sqlalchemy.sql.expression import and_, or_ from configs import dify_config from core.app.entities.app_invoke_entities import InvokeFrom +from core.app.file_access import DatabaseFileAccessController from core.trigger.constants import is_trigger_node_type -from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID -from dify_graph.enums import NodeType, SystemVariableKey -from dify_graph.file.models import File -from dify_graph.nodes import BuiltinNodeTypes -from dify_graph.nodes.variable_assigner.common.helpers import get_updated_variables -from dify_graph.variable_loader import VariableLoader -from dify_graph.variables import Segment, StringSegment, VariableBase -from dify_graph.variables.consts import SELECTORS_LENGTH -from dify_graph.variables.segments import ( - ArrayFileSegment, - FileSegment, +from core.workflow.system_variables import SystemVariableKey +from core.workflow.variable_prefixes import ( + CONVERSATION_VARIABLE_NODE_ID, + ENVIRONMENT_VARIABLE_NODE_ID, + RAG_PIPELINE_VARIABLE_NODE_ID, + SYSTEM_VARIABLE_NODE_ID, ) -from dify_graph.variables.types import SegmentType -from dify_graph.variables.utils import dumps_with_segments from extensions.ext_storage import storage from factories.file_factory import StorageKeyLoader from factories.variable_factory import build_segment, segment_to_variable @@ -36,6 +43,7 @@ from libs.datetime_utils import naive_utc_now from libs.uuid_utils import uuidv7 from models import Account, App, Conversation from models.enums import ConversationFromSource, DraftVariableType +from models.utils.file_input_compat import build_file_from_stored_mapping from models.workflow import Workflow, WorkflowDraftVariable, WorkflowDraftVariableFile, is_system_variable_editable from repositories.factory import DifyAPIRepositoryFactory 
from services.file_service import FileService @@ -71,7 +79,7 @@ class UpdateNotSupportedError(WorkflowDraftVariableError): class DraftVarLoader(VariableLoader): # This implements the VariableLoader interface for loading draft variables. # - # ref: dify_graph.variable_loader.VariableLoader + # ref: graphon.variable_loader.VariableLoader # Database engine used for loading variables. _engine: Engine @@ -120,7 +128,11 @@ class DraftVarLoader(VariableLoader): elif isinstance(value, ArrayFileSegment): files.extend(value.value) with Session(bind=self._engine) as session: - storage_key_loader = StorageKeyLoader(session, tenant_id=self._tenant_id) + storage_key_loader = StorageKeyLoader( + session, + tenant_id=self._tenant_id, + access_controller=DatabaseFileAccessController(), + ) storage_key_loader.load_storage_keys(files) offloaded_draft_vars = [] @@ -174,7 +186,7 @@ class DraftVarLoader(VariableLoader): return (draft_var.node_id, draft_var.name), variable deserialized = json.loads(content) - segment = WorkflowDraftVariable.build_segment_with_type(variable_file.value_type, deserialized) + segment = draft_var.build_segment_from_serialized_value(variable_file.value_type, deserialized) variable = segment_to_variable( segment=segment, selector=draft_var.get_selector(), @@ -838,6 +850,12 @@ class DraftVariableSaver: self._user = user self._enclosing_node_id = enclosing_node_id + def _resolve_app_tenant_id(self) -> str: + tenant_id = self._session.scalar(select(App.tenant_id).where(App.id == self._app_id)) + if not tenant_id: + raise ValueError(f"Unable to resolve tenant_id for app {self._app_id}") + return tenant_id + def _create_dummy_output_variable(self): return WorkflowDraftVariable.new_node_variable( app_id=self._app_id, @@ -892,27 +910,18 @@ class DraftVariableSaver: for name, value in output.items(): value_seg = _build_segment_for_serialized_values(value) node_id, name = self._normalize_variable_for_start_node(name) - # If node_id is not `sys`, it means that the 
variable is a user-defined input field - # in `Start` node. - if node_id != SYSTEM_VARIABLE_NODE_ID: - draft_vars.append( - WorkflowDraftVariable.new_node_variable( - app_id=self._app_id, - user_id=self._user.id, - node_id=self._node_id, - name=name, - node_execution_id=self._node_execution_id, - value=value_seg, - visible=True, - editable=True, - ) - ) - has_non_sys_variables = True - else: + if node_id == SYSTEM_VARIABLE_NODE_ID: if name == SystemVariableKey.FILES: # Here we know the type of variable must be `array[file]`, we - # just build files from the value. - files = [File.model_validate(v) for v in value] + # just rebuild files from the serialized payload. + tenant_id = self._resolve_app_tenant_id() + files = [ + build_file_from_stored_mapping( + file_mapping=v, + tenant_id=tenant_id, + ) + for v in value + ] if files: value_seg = WorkflowDraftVariable.build_segment_with_type(SegmentType.ARRAY_FILE, files) else: @@ -928,15 +937,47 @@ class DraftVariableSaver: editable=self._should_variable_be_editable(node_id, name), ) ) + elif node_id == CONVERSATION_VARIABLE_NODE_ID: + draft_vars.append( + WorkflowDraftVariable.new_conversation_variable( + app_id=self._app_id, + user_id=self._user.id, + name=name, + value=value_seg, + ) + ) + has_non_sys_variables = True + else: + draft_vars.append( + WorkflowDraftVariable.new_node_variable( + app_id=self._app_id, + user_id=self._user.id, + node_id=node_id, + name=name, + node_execution_id=self._node_execution_id, + value=value_seg, + visible=self._should_variable_be_visible(node_id, self._node_type, name), + editable=self._should_variable_be_editable(node_id, name), + ) + ) + has_non_sys_variables = True if not has_non_sys_variables: draft_vars.append(self._create_dummy_output_variable()) return draft_vars def _normalize_variable_for_start_node(self, name: str) -> tuple[str, str]: - if not name.startswith(f"{SYSTEM_VARIABLE_NODE_ID}."): - return self._node_id, name - _, name_ = name.split(".", maxsplit=1) - return 
SYSTEM_VARIABLE_NODE_ID, name_ + for reserved_node_id in ( + SYSTEM_VARIABLE_NODE_ID, + ENVIRONMENT_VARIABLE_NODE_ID, + CONVERSATION_VARIABLE_NODE_ID, + RAG_PIPELINE_VARIABLE_NODE_ID, + ): + prefix = f"{reserved_node_id}." + if name.startswith(prefix): + _, name_ = name.split(".", maxsplit=1) + return reserved_node_id, name_ + + return self._node_id, name def _build_variables_from_mapping(self, output: Mapping[str, Any]) -> list[WorkflowDraftVariable]: draft_vars = [] diff --git a/api/services/workflow_event_snapshot_service.py b/api/services/workflow_event_snapshot_service.py index 8f323ebb8b..601e9261fc 100644 --- a/api/services/workflow_event_snapshot_service.py +++ b/api/services/workflow_event_snapshot_service.py @@ -9,6 +9,10 @@ from collections.abc import Generator, Mapping, Sequence from dataclasses import dataclass from typing import Any +from graphon.entities import WorkflowStartReason +from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus +from graphon.runtime import GraphRuntimeState +from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from sqlalchemy import desc, select from sqlalchemy.orm import Session, sessionmaker @@ -22,10 +26,6 @@ from core.app.entities.task_entities import ( WorkflowStartStreamResponse, ) from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext -from dify_graph.entities import WorkflowStartReason -from dify_graph.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus -from dify_graph.runtime import GraphRuntimeState -from dify_graph.workflow_type_encoder import WorkflowRuntimeTypeConverter from models.model import AppMode, Message from models.workflow import WorkflowNodeExecutionTriggeredFrom, WorkflowRun from repositories.api_workflow_node_execution_repository import WorkflowNodeExecutionSnapshot diff --git a/api/services/workflow_service.py b/api/services/workflow_service.py index 66976058c0..b555676704 100644 --- a/api/services/workflow_service.py 
+++ b/api/services/workflow_service.py @@ -5,6 +5,31 @@ import uuid from collections.abc import Callable, Generator, Mapping, Sequence from typing import Any, cast +from graphon.entities import GraphInitParams, WorkflowNodeExecution +from graphon.entities.graph_config import NodeConfigDict +from graphon.entities.pause_reason import HumanInputRequired +from graphon.enums import ( + ErrorStrategy, + NodeType, + WorkflowNodeExecutionMetadataKey, + WorkflowNodeExecutionStatus, +) +from graphon.errors import WorkflowNodeRunFailedError +from graphon.file import File +from graphon.graph_events import GraphNodeEventBase, NodeRunFailedEvent, NodeRunSucceededEvent +from graphon.node_events import NodeRunResult +from graphon.nodes import BuiltinNodeTypes +from graphon.nodes.base.node import Node +from graphon.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, build_http_request_config +from graphon.nodes.human_input.entities import HumanInputNodeData, validate_human_input_submission +from graphon.nodes.human_input.enums import HumanInputFormKind +from graphon.nodes.human_input.human_input_node import HumanInputNode +from graphon.nodes.start.entities import StartNodeData +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.variable_loader import load_into_variable_pool +from graphon.variables import VariableBase +from graphon.variables.input_entities import VariableEntityType +from graphon.variables.variables import Variable from sqlalchemy import exists, select from sqlalchemy.orm import Session, sessionmaker @@ -12,43 +37,22 @@ from configs import dify_config from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfigManager from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context +from core.app.file_access import DatabaseFileAccessController +from core.plugin.impl.model_runtime_factory import 
create_plugin_model_assembly, create_plugin_provider_manager from core.repositories import DifyCoreRepositoryFactory -from core.repositories.human_input_repository import HumanInputFormRepositoryImpl +from core.repositories.human_input_repository import FormCreateParams, HumanInputFormRepositoryImpl from core.trigger.constants import is_trigger_node_type -from core.workflow.node_factory import LATEST_VERSION, get_node_type_classes_mapping, is_start_node_type -from core.workflow.workflow_entry import WorkflowEntry -from dify_graph.entities import GraphInitParams, WorkflowNodeExecution -from dify_graph.entities.graph_config import NodeConfigDict -from dify_graph.entities.pause_reason import HumanInputRequired -from dify_graph.enums import ( - ErrorStrategy, - NodeType, - WorkflowNodeExecutionMetadataKey, - WorkflowNodeExecutionStatus, -) -from dify_graph.errors import WorkflowNodeRunFailedError -from dify_graph.file import File -from dify_graph.graph_events import GraphNodeEventBase, NodeRunFailedEvent, NodeRunSucceededEvent -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes import BuiltinNodeTypes -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, build_http_request_config -from dify_graph.nodes.human_input.entities import ( +from core.workflow.human_input_compat import ( DeliveryChannelConfig, - HumanInputNodeData, - apply_debug_email_recipient, - validate_human_input_submission, + normalize_human_input_node_data_for_graph, + parse_human_input_delivery_methods, ) -from dify_graph.nodes.human_input.enums import HumanInputFormKind -from dify_graph.nodes.human_input.human_input_node import HumanInputNode -from dify_graph.nodes.start.entities import StartNodeData -from dify_graph.repositories.human_input_form_repository import FormCreateParams -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable -from 
dify_graph.variable_loader import load_into_variable_pool -from dify_graph.variables import VariableBase -from dify_graph.variables.input_entities import VariableEntityType -from dify_graph.variables.variables import Variable +from core.workflow.node_factory import LATEST_VERSION, get_node_type_classes_mapping, is_start_node_type +from core.workflow.node_runtime import DifyHumanInputNodeRuntime, apply_dify_debug_email_recipient +from core.workflow.system_variables import build_bootstrap_variables, build_system_variables, default_system_variables +from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool +from core.workflow.workflow_entry import WorkflowEntry +from enterprise.telemetry.draft_trace import enqueue_draft_node_execution_trace from enums.cloud_plan import CloudPlan from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated from extensions.ext_database import db @@ -82,6 +86,8 @@ from .human_input_delivery_test_service import ( from .workflow_draft_variable_service import DraftVariableSaver, DraftVarLoader, WorkflowDraftVariableService from .workflow_restore import apply_published_workflow_snapshot_to_draft +_file_access_controller = DatabaseFileAccessController() + class WorkflowService: """ @@ -486,13 +492,15 @@ class WorkflowService: :raises ValueError: If the model configuration is invalid or credentials fail policy checks """ try: - from core.model_manager import ModelManager - from core.provider_manager import ProviderManager - from dify_graph.model_runtime.entities.model_entities import ModelType + from graphon.model_runtime.entities.model_entities import ModelType + + # Model instance resolution and provider status lookup must reuse the + # same request-scoped runtime so validation does not silently split + # provider discovery and credential reads across different caches. 
+ assembly = create_plugin_model_assembly(tenant_id=tenant_id) # Get model instance to validate provider+model combination - model_manager = ModelManager() - model_manager.get_model_instance( + assembly.model_manager.get_model_instance( tenant_id=tenant_id, provider=provider, model_type=ModelType.LLM, model=model_name ) @@ -501,8 +509,7 @@ class WorkflowService: # If it fails, an exception will be raised # Additionally, check the model status to ensure it's ACTIVE - provider_manager = ProviderManager() - provider_configurations = provider_manager.get_configurations(tenant_id) + provider_configurations = assembly.provider_manager.get_configurations(tenant_id) models = provider_configurations.get_models(provider=provider, model_type=ModelType.LLM) target_model = None @@ -607,11 +614,10 @@ class WorkflowService: :return: True if load balancing is enabled, False otherwise """ try: - from core.provider_manager import ProviderManager - from dify_graph.model_runtime.entities.model_entities import ModelType + from graphon.model_runtime.entities.model_entities import ModelType # Get provider configurations - provider_manager = ProviderManager() + provider_manager = create_plugin_provider_manager(tenant_id=tenant_id) provider_configurations = provider_manager.get_configurations(tenant_id) provider_configuration = provider_configurations.get(provider) @@ -765,6 +771,7 @@ class WorkflowService: user_id=account.id, user_inputs=user_inputs, workflow=draft_workflow, + node_id=node_id, # NOTE(QuantumGhost): We rely on `DraftVarLoader` to load conversation variables. 
conversation_variables=[], node_type=node_type, @@ -772,11 +779,13 @@ class WorkflowService: ) else: - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs=user_inputs, - environment_variables=draft_workflow.environment_variables, - conversation_variables=[], + variable_pool = VariablePool() + add_variables_to_pool( + variable_pool, + build_bootstrap_variables( + system_variables=default_system_variables(), + environment_variables=draft_workflow.environment_variables, + ), ) variable_loader = DraftVarLoader( @@ -841,6 +850,13 @@ class WorkflowService: draft_var_saver.save(process_data=node_execution.process_data, outputs=outputs) session.commit() + enqueue_draft_node_execution_trace( + execution=workflow_node_execution, + outputs=outputs, + workflow_execution_id=None, + user_id=account.id, + ) + return workflow_node_execution def get_human_input_form_preview( @@ -895,7 +911,6 @@ class WorkflowService: node_id=node_id, node_title=node.title, resolved_default_values=resolved_default_values, - form_token=None, ) return human_input_required.model_dump(mode="json") @@ -995,17 +1010,20 @@ class WorkflowService: if node_type != BuiltinNodeTypes.HUMAN_INPUT: raise ValueError("Node type must be human-input.") - node_data = HumanInputNodeData.model_validate(node_config["data"], from_attributes=True) + node_data = HumanInputNodeData.model_validate( + normalize_human_input_node_data_for_graph(node_config["data"]), + from_attributes=True, + ) delivery_method = self._resolve_human_input_delivery_method( node_data=node_data, delivery_method_id=delivery_method_id, ) if delivery_method is None: raise ValueError("Delivery method not found.") - delivery_method = apply_debug_email_recipient( + delivery_method = apply_dify_debug_email_recipient( delivery_method, enabled=True, - user_id=account.id, + actor_id=account.id, ) variable_pool = self._build_human_input_variable_pool( @@ -1055,7 +1073,7 @@ class WorkflowService: node_data: HumanInputNodeData, 
delivery_method_id: str, ) -> DeliveryChannelConfig | None: - for method in node_data.delivery_methods: + for method in parse_human_input_delivery_methods(node_data): if str(method.id) == delivery_method_id: return method return None @@ -1070,9 +1088,8 @@ class WorkflowService: rendered_content: str, resolved_default_values: Mapping[str, Any], ) -> tuple[str, list[DeliveryTestEmailRecipient]]: - repo = HumanInputFormRepositoryImpl(tenant_id=app_model.tenant_id) + repo = HumanInputFormRepositoryImpl(tenant_id=app_model.tenant_id, app_id=app_model.id) params = FormCreateParams( - app_id=app_model.id, workflow_execution_id=None, node_id=node_id, form_config=node_data, @@ -1138,7 +1155,7 @@ class WorkflowService: config=node_config, graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, - form_repository=HumanInputFormRepositoryImpl(tenant_id=workflow.tenant_id), + runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context), ) return node @@ -1155,11 +1172,13 @@ class WorkflowService: draft_var_srv = WorkflowDraftVariableService(session) draft_var_srv.prefill_conversation_variable_default_values(workflow, user_id=user_id) - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={}, - environment_variables=workflow.environment_variables, - conversation_variables=[], + variable_pool = VariablePool() + add_variables_to_pool( + variable_pool, + build_bootstrap_variables( + system_variables=default_system_variables(), + environment_variables=workflow.environment_variables, + ), ) variable_loader = DraftVarLoader( @@ -1419,10 +1438,10 @@ class WorkflowService: Raises: ValueError: If the node data format is invalid """ - from dify_graph.nodes.human_input.entities import HumanInputNodeData + from graphon.nodes.human_input.entities import HumanInputNodeData try: - HumanInputNodeData.model_validate(node_data) + HumanInputNodeData.model_validate(normalize_human_input_node_data_for_graph(node_data)) except Exception as 
e: raise ValueError(f"Invalid HumanInput node data: {str(e)}") @@ -1511,38 +1530,48 @@ def _setup_variable_pool( user_id: str, user_inputs: Mapping[str, Any], workflow: Workflow, + node_id: str, node_type: NodeType, conversation_id: str, conversation_variables: list[VariableBase], ): # Only inject system variables for START node type. if is_start_node_type(node_type): - system_variable = SystemVariable( - user_id=user_id, - app_id=workflow.app_id, - timestamp=int(naive_utc_now().timestamp()), - workflow_id=workflow.id, - files=files or [], - workflow_execution_id=str(uuid.uuid4()), - ) + system_variable_values: dict[str, Any] = { + "user_id": user_id, + "app_id": workflow.app_id, + "timestamp": int(naive_utc_now().timestamp()), + "workflow_id": workflow.id, + "files": files or [], + "workflow_execution_id": str(uuid.uuid4()), + } - # Only add chatflow-specific variables for non-workflow types + # Only add chatflow-specific variables for non-workflow types. if workflow.type != WorkflowType.WORKFLOW: - system_variable.query = query - system_variable.conversation_id = conversation_id - system_variable.dialogue_count = 1 + system_variable_values.update( + { + "query": query, + "conversation_id": conversation_id, + "dialogue_count": 1, + } + ) + + system_variable = build_system_variables(system_variable_values) else: - system_variable = SystemVariable.default() + system_variable = default_system_variables() # init variable pool - variable_pool = VariablePool( - system_variables=system_variable, - user_inputs=user_inputs, - environment_variables=workflow.environment_variables, - # Based on the definition of `Variable`, - # `VariableBase` instances can be safely used as `Variable` since they are compatible. 
- conversation_variables=cast(list[Variable], conversation_variables), # + variable_pool = VariablePool() + add_variables_to_pool( + variable_pool, + build_bootstrap_variables( + system_variables=system_variable, + environment_variables=workflow.environment_variables, + conversation_variables=cast(list[Variable], conversation_variables), + ), ) + if is_start_node_type(node_type): + add_node_inputs_to_pool(variable_pool, node_id=node_id, inputs=user_inputs) return variable_pool @@ -1567,7 +1596,7 @@ def _rebuild_single_file(tenant_id: str, value: Any, variable_entity_type: Varia if variable_entity_type == VariableEntityType.FILE: if not isinstance(value, dict): raise ValueError(f"expected dict for file object, got {type(value)}") - return build_from_mapping(mapping=value, tenant_id=tenant_id) + return build_from_mapping(mapping=value, tenant_id=tenant_id, access_controller=_file_access_controller) elif variable_entity_type == VariableEntityType.FILE_LIST: if not isinstance(value, list): raise ValueError(f"expected list for file list object, got {type(value)}") @@ -1575,6 +1604,6 @@ def _rebuild_single_file(tenant_id: str, value: Any, variable_entity_type: Varia return [] if not isinstance(value[0], dict): raise ValueError(f"expected dict for first element in the file list, got {type(value)}") - return build_from_mappings(mappings=value, tenant_id=tenant_id) + return build_from_mappings(mappings=value, tenant_id=tenant_id, access_controller=_file_access_controller) else: raise Exception("unreachable") diff --git a/api/tasks/app_generate/workflow_execute_task.py b/api/tasks/app_generate/workflow_execute_task.py index 174aa50343..489467651d 100644 --- a/api/tasks/app_generate/workflow_execute_task.py +++ b/api/tasks/app_generate/workflow_execute_task.py @@ -7,6 +7,7 @@ from typing import Annotated, Any, TypeAlias, Union from celery import shared_task from flask import current_app, json +from graphon.runtime import GraphRuntimeState from pydantic import BaseModel, 
Discriminator, Field, Tag from sqlalchemy import Engine, select from sqlalchemy.orm import Session, sessionmaker @@ -21,7 +22,6 @@ from core.app.entities.app_invoke_entities import ( ) from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, WorkflowResumptionContext from core.repositories import DifyCoreRepositoryFactory -from dify_graph.runtime import GraphRuntimeState from extensions.ext_database import db from libs.flask_utils import set_login_user from models.account import Account @@ -239,13 +239,18 @@ def _resolve_user_for_run(session: Session, workflow_run: WorkflowRun) -> Accoun def _publish_streaming_response( - response_stream: Generator[str | Mapping[str, Any], None, None], workflow_run_id: str, app_mode: AppMode + response_stream: Generator[str | Mapping[str, Any] | BaseModel, None, None], + workflow_run_id: str, + app_mode: AppMode, ) -> None: topic = MessageBasedAppGenerator.get_response_topic(app_mode, workflow_run_id) for event in response_stream: try: - payload = json.dumps(event) - except TypeError: + if isinstance(event, BaseModel): + payload = json.dumps(event.model_dump(mode="json"), ensure_ascii=False) + else: + payload = json.dumps(event, ensure_ascii=False, default=str) + except (TypeError, ValueError): logger.exception("error while encoding event") continue diff --git a/api/tasks/async_workflow_tasks.py b/api/tasks/async_workflow_tasks.py index d247cf5cf7..0a73c91279 100644 --- a/api/tasks/async_workflow_tasks.py +++ b/api/tasks/async_workflow_tasks.py @@ -10,6 +10,7 @@ from datetime import UTC, datetime from typing import Any from celery import shared_task +from graphon.runtime import GraphRuntimeState from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker @@ -21,7 +22,6 @@ from core.app.layers.timeslice_layer import TimeSliceLayer from core.app.layers.trigger_post_layer import TriggerPostLayer from core.db.session_factory import session_factory from core.repositories import 
DifyCoreRepositoryFactory -from dify_graph.runtime import GraphRuntimeState from extensions.ext_database import db from models.account import Account from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom, WorkflowTriggerStatus diff --git a/api/tasks/batch_create_segment_to_index_task.py b/api/tasks/batch_create_segment_to_index_task.py index dd58378e0e..20335d9b9f 100644 --- a/api/tasks/batch_create_segment_to_index_task.py +++ b/api/tasks/batch_create_segment_to_index_task.py @@ -7,12 +7,12 @@ from pathlib import Path import click import pandas as pd from celery import shared_task +from graphon.model_runtime.entities.model_entities import ModelType from sqlalchemy import func from core.db.session_factory import session_factory from core.model_manager import ModelManager from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType -from dify_graph.model_runtime.entities.model_entities import ModelType from extensions.ext_redis import redis_client from extensions.ext_storage import storage from libs import helper @@ -121,7 +121,7 @@ def batch_create_segment_to_index_task( document_segments = [] embedding_model = None if dataset_config["indexing_technique"] == IndexTechniqueType.HIGH_QUALITY: - model_manager = ModelManager() + model_manager = ModelManager.for_tenant(tenant_id=dataset_config["tenant_id"]) embedding_model = model_manager.get_model_instance( tenant_id=dataset_config["tenant_id"], provider=dataset_config["embedding_model_provider"], diff --git a/api/tasks/enterprise_telemetry_task.py b/api/tasks/enterprise_telemetry_task.py new file mode 100644 index 0000000000..7d5ea7c0a5 --- /dev/null +++ b/api/tasks/enterprise_telemetry_task.py @@ -0,0 +1,52 @@ +"""Celery worker for enterprise metric/log telemetry events. + +This module defines the Celery task that processes telemetry envelopes +from the enterprise_telemetry queue. It deserializes envelopes and +dispatches them to the EnterpriseMetricHandler. 
+""" + +import json +import logging + +from celery import shared_task + +from enterprise.telemetry.contracts import TelemetryEnvelope +from enterprise.telemetry.metric_handler import EnterpriseMetricHandler + +logger = logging.getLogger(__name__) + + +@shared_task(queue="enterprise_telemetry") +def process_enterprise_telemetry(envelope_json: str) -> None: + """Process enterprise metric/log telemetry envelope. + + This task is enqueued by the TelemetryGateway for metric/log-only + events. It deserializes the envelope and dispatches to the handler. + + Best-effort processing: logs errors but never raises, to avoid + failing user requests due to telemetry issues. + + Args: + envelope_json: JSON-serialized TelemetryEnvelope. + """ + try: + # Deserialize envelope + envelope_dict = json.loads(envelope_json) + envelope = TelemetryEnvelope.model_validate(envelope_dict) + + # Process through handler + handler = EnterpriseMetricHandler() + handler.handle(envelope) + + logger.debug( + "Successfully processed telemetry envelope: tenant_id=%s, event_id=%s, case=%s", + envelope.tenant_id, + envelope.event_id, + envelope.case, + ) + except Exception: + # Best-effort: log and drop on error, never fail user request + logger.warning( + "Failed to process enterprise telemetry envelope, dropping event", + exc_info=True, + ) diff --git a/api/tasks/human_input_timeout_tasks.py b/api/tasks/human_input_timeout_tasks.py index dd3b6a4530..ca73b4d374 100644 --- a/api/tasks/human_input_timeout_tasks.py +++ b/api/tasks/human_input_timeout_tasks.py @@ -2,13 +2,13 @@ import logging from datetime import timedelta from celery import shared_task +from graphon.enums import WorkflowExecutionStatus +from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from sqlalchemy import or_, select from sqlalchemy.orm import sessionmaker from configs import dify_config from core.repositories.human_input_repository import HumanInputFormSubmissionRepository -from dify_graph.enums 
import WorkflowExecutionStatus -from dify_graph.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from extensions.ext_database import db from extensions.ext_storage import storage from libs.datetime_utils import ensure_naive_utc, naive_utc_now diff --git a/api/tasks/mail_human_input_delivery_task.py b/api/tasks/mail_human_input_delivery_task.py index d241783359..a316eec7b9 100644 --- a/api/tasks/mail_human_input_delivery_task.py +++ b/api/tasks/mail_human_input_delivery_task.py @@ -6,13 +6,13 @@ from typing import Any import click from celery import shared_task +from graphon.runtime import GraphRuntimeState, VariablePool from sqlalchemy import select from sqlalchemy.orm import Session, sessionmaker from configs import dify_config from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext -from dify_graph.nodes.human_input.entities import EmailDeliveryConfig, EmailDeliveryMethod -from dify_graph.runtime import GraphRuntimeState, VariablePool +from core.workflow.human_input_compat import EmailDeliveryConfig, EmailDeliveryMethod from extensions.ext_database import db from extensions.ext_mail import mail from models.human_input import ( diff --git a/api/tasks/ops_trace_task.py b/api/tasks/ops_trace_task.py index 72e3b42ca7..c95b8db078 100644 --- a/api/tasks/ops_trace_task.py +++ b/api/tasks/ops_trace_task.py @@ -39,17 +39,36 @@ def process_trace_tasks(file_info): trace_info["documents"] = [Document.model_validate(doc) for doc in trace_info["documents"]] try: + trace_type = trace_info_info_map.get(trace_info_type) + if trace_type: + trace_info = trace_type(**trace_info) + + from extensions.ext_enterprise_telemetry import is_enabled as is_ee_telemetry_enabled + + if is_ee_telemetry_enabled(): + from enterprise.telemetry.enterprise_trace import EnterpriseOtelTrace + + try: + EnterpriseOtelTrace().trace(trace_info) + except Exception: + logger.exception("Enterprise trace failed for app_id: %s", app_id) + if trace_instance: with 
current_app.app_context(): - trace_type = trace_info_info_map.get(trace_info_type) - if trace_type: - trace_info = trace_type(**trace_info) trace_instance.trace(trace_info) + logger.info("Processing trace tasks success, app_id: %s", app_id) except Exception as e: - logger.info("error:\n\n\n%s\n\n\n\n", e) + logger.exception("Processing trace tasks failed, app_id: %s", app_id) failed_key = f"{OPS_TRACE_FAILED_KEY}_{app_id}" redis_client.incr(failed_key) - logger.info("Processing trace tasks failed, app_id: %s", app_id) finally: - storage.delete(file_path) + try: + storage.delete(file_path) + except Exception as e: + logger.warning( + "Failed to delete trace file %s for app_id %s: %s", + file_path, + app_id, + e, + ) diff --git a/api/tasks/trigger_processing_tasks.py b/api/tasks/trigger_processing_tasks.py index f8c7964805..56626e372e 100644 --- a/api/tasks/trigger_processing_tasks.py +++ b/api/tasks/trigger_processing_tasks.py @@ -12,6 +12,7 @@ from datetime import UTC, datetime from typing import Any from celery import shared_task +from graphon.enums import WorkflowExecutionStatus from sqlalchemy import func, select from sqlalchemy.orm import Session @@ -27,7 +28,6 @@ from core.trigger.entities.entities import TriggerProviderEntity from core.trigger.provider import PluginTriggerProviderController from core.trigger.trigger_manager import TriggerManager from core.workflow.nodes.trigger_plugin.entities import TriggerEventNodeData -from dify_graph.enums import WorkflowExecutionStatus from enums.quota_type import QuotaType, unlimited from models.enums import ( AppTriggerType, diff --git a/api/tasks/workflow_execution_tasks.py b/api/tasks/workflow_execution_tasks.py index f41118e592..0c7f74c180 100644 --- a/api/tasks/workflow_execution_tasks.py +++ b/api/tasks/workflow_execution_tasks.py @@ -9,11 +9,11 @@ import json import logging from celery import shared_task +from graphon.entities import WorkflowExecution +from graphon.workflow_type_encoder import 
WorkflowRuntimeTypeConverter from sqlalchemy import select from core.db.session_factory import session_factory -from dify_graph.entities.workflow_execution import WorkflowExecution -from dify_graph.workflow_type_encoder import WorkflowRuntimeTypeConverter from models import CreatorUserRole, WorkflowRun from models.enums import WorkflowRunTriggeredFrom diff --git a/api/tasks/workflow_node_execution_tasks.py b/api/tasks/workflow_node_execution_tasks.py index 466ef6c858..f25ebe3bae 100644 --- a/api/tasks/workflow_node_execution_tasks.py +++ b/api/tasks/workflow_node_execution_tasks.py @@ -9,13 +9,13 @@ import json import logging from celery import shared_task +from graphon.entities.workflow_node_execution import ( + WorkflowNodeExecution, +) +from graphon.workflow_type_encoder import WorkflowRuntimeTypeConverter from sqlalchemy import select from core.db.session_factory import session_factory -from dify_graph.entities.workflow_node_execution import ( - WorkflowNodeExecution, -) -from dify_graph.workflow_type_encoder import WorkflowRuntimeTypeConverter from models import CreatorUserRole, WorkflowNodeExecutionModel from models.workflow import WorkflowNodeExecutionTriggeredFrom @@ -125,7 +125,7 @@ def _create_node_execution_from_domain( else: node_execution.execution_metadata = "{}" - node_execution.status = execution.status.value + node_execution.status = execution.status node_execution.error = execution.error node_execution.elapsed_time = execution.elapsed_time node_execution.created_by_role = creator_user_role @@ -159,7 +159,7 @@ def _update_node_execution_from_domain(node_execution: WorkflowNodeExecutionMode node_execution.execution_metadata = "{}" # Update other fields - node_execution.status = execution.status.value + node_execution.status = execution.status node_execution.error = execution.error node_execution.elapsed_time = execution.elapsed_time node_execution.finished_at = execution.finished_at diff --git 
a/api/tests/integration_tests/core/datasource/test_datasource_manager_integration.py b/api/tests/integration_tests/core/datasource/test_datasource_manager_integration.py index 4fdbb7d9f3..91245e879e 100644 --- a/api/tests/integration_tests/core/datasource/test_datasource_manager_integration.py +++ b/api/tests/integration_tests/core/datasource/test_datasource_manager_integration.py @@ -1,8 +1,9 @@ from collections.abc import Generator +from graphon.node_events import StreamCompletedEvent + from core.datasource.datasource_manager import DatasourceManager from core.datasource.entities.datasource_entities import DatasourceMessage -from dify_graph.node_events import StreamCompletedEvent def _gen_var_stream() -> Generator[DatasourceMessage, None, None]: diff --git a/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py b/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py index 3e79792b5b..3fdea10976 100644 --- a/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py +++ b/api/tests/integration_tests/core/workflow/nodes/datasource/test_datasource_node_integration.py @@ -1,7 +1,8 @@ +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.node_events import NodeRunResult, StreamCompletedEvent + +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY from core.workflow.nodes.datasource.datasource_node import DatasourceNode -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY -from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from dify_graph.node_events import NodeRunResult, StreamCompletedEvent class _Seg: diff --git a/api/tests/integration_tests/factories/test_storage_key_loader.py b/api/tests/integration_tests/factories/test_storage_key_loader.py index db4bbc1ca1..c1bb8e1245 100644 --- a/api/tests/integration_tests/factories/test_storage_key_loader.py +++ 
b/api/tests/integration_tests/factories/test_storage_key_loader.py @@ -4,9 +4,10 @@ from unittest.mock import patch from uuid import uuid4 import pytest +from graphon.file import File, FileTransferMethod, FileType from sqlalchemy.orm import Session -from dify_graph.file import File, FileTransferMethod, FileType +from core.app.file_access import DatabaseFileAccessController from extensions.ext_database import db from extensions.storage.storage_type import StorageType from factories.file_factory import StorageKeyLoader @@ -35,7 +36,11 @@ class TestStorageKeyLoader(unittest.TestCase): self.test_tool_files = [] # Create StorageKeyLoader instance - self.loader = StorageKeyLoader(self.session, self.tenant_id) + self.loader = StorageKeyLoader( + self.session, + self.tenant_id, + access_controller=DatabaseFileAccessController(), + ) def tearDown(self): """Clean up test data after each test method.""" @@ -192,19 +197,16 @@ class TestStorageKeyLoader(unittest.TestCase): # Should not raise any exceptions self.loader.load_storage_keys([]) - def test_load_storage_keys_tenant_mismatch(self): - """Test tenant_id validation.""" - # Create file with different tenant_id + def test_load_storage_keys_ignores_legacy_file_tenant_id(self): + """Legacy file tenant_id should not override the loader tenant scope.""" upload_file = self._create_upload_file() file = self._create_file( related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE, tenant_id=str(uuid4()) ) - # Should raise ValueError for tenant mismatch - with pytest.raises(ValueError) as context: - self.loader.load_storage_keys([file]) + self.loader.load_storage_keys([file]) - assert "invalid file, expected tenant_id" in str(context.value) + assert file._storage_key == upload_file.key def test_load_storage_keys_missing_file_id(self): """Test with None file.related_id.""" @@ -313,7 +315,7 @@ class TestStorageKeyLoader(unittest.TestCase): with pytest.raises(ValueError) as context: 
self.loader.load_storage_keys([file_other]) - assert "invalid file, expected tenant_id" in str(context.value) + assert "Upload file not found for id:" in str(context.value) # Current tenant's file should still work self.loader.load_storage_keys([file_current]) @@ -337,7 +339,7 @@ class TestStorageKeyLoader(unittest.TestCase): with pytest.raises(ValueError) as context: self.loader.load_storage_keys([file_current, file_other]) - assert "invalid file, expected tenant_id" in str(context.value) + assert "Upload file not found for id:" in str(context.value) def test_load_storage_keys_duplicate_file_ids(self): """Test handling of duplicate file IDs in the batch.""" @@ -364,6 +366,10 @@ class TestStorageKeyLoader(unittest.TestCase): # Create loader with different session (same underlying connection) with Session(bind=db.engine) as other_session: - other_loader = StorageKeyLoader(other_session, self.tenant_id) + other_loader = StorageKeyLoader( + other_session, + self.tenant_id, + access_controller=DatabaseFileAccessController(), + ) with pytest.raises(ValueError): other_loader.load_storage_keys([file]) diff --git a/api/tests/integration_tests/model_runtime/__mock/plugin_model.py b/api/tests/integration_tests/model_runtime/__mock/plugin_model.py index 4e184c93fd..ce04a158a8 100644 --- a/api/tests/integration_tests/model_runtime/__mock/plugin_model.py +++ b/api/tests/integration_tests/model_runtime/__mock/plugin_model.py @@ -4,27 +4,27 @@ from collections.abc import Generator, Sequence from decimal import Decimal from json import dumps -from core.plugin.entities.plugin_daemon import PluginModelProviderEntity -from core.plugin.impl.model import PluginModelClient - # import monkeypatch -from dify_graph.model_runtime.entities.common_entities import I18nObject -from dify_graph.model_runtime.entities.llm_entities import ( +from graphon.model_runtime.entities.common_entities import I18nObject +from graphon.model_runtime.entities.llm_entities import ( LLMMode, LLMResult, 
LLMResultChunk, LLMResultChunkDelta, LLMUsage, ) -from dify_graph.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage, PromptMessageTool -from dify_graph.model_runtime.entities.model_entities import ( +from graphon.model_runtime.entities.message_entities import AssistantPromptMessage, PromptMessage, PromptMessageTool +from graphon.model_runtime.entities.model_entities import ( AIModelEntity, FetchFrom, ModelFeature, ModelPropertyKey, ModelType, ) -from dify_graph.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity +from graphon.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity + +from core.plugin.entities.plugin_daemon import PluginModelProviderEntity +from core.plugin.impl.model import PluginModelClient class MockModelClass(PluginModelClient): diff --git a/api/tests/integration_tests/services/test_workflow_draft_variable_service.py b/api/tests/integration_tests/services/test_workflow_draft_variable_service.py index 9d3a869691..5c6636f31e 100644 --- a/api/tests/integration_tests/services/test_workflow_draft_variable_service.py +++ b/api/tests/integration_tests/services/test_workflow_draft_variable_service.py @@ -3,14 +3,14 @@ import unittest import uuid import pytest +from graphon.nodes import BuiltinNodeTypes +from graphon.variables.segments import StringSegment +from graphon.variables.types import SegmentType +from graphon.variables.variables import StringVariable from sqlalchemy import delete from sqlalchemy.orm import Session -from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID -from dify_graph.nodes import BuiltinNodeTypes -from dify_graph.variables.segments import StringSegment -from dify_graph.variables.types import SegmentType -from dify_graph.variables.variables import StringVariable +from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID from extensions.ext_database import db from 
extensions.ext_storage import storage from extensions.storage.storage_type import StorageType diff --git a/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py b/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py index bc83c6cc12..38dc8bbb28 100644 --- a/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py +++ b/api/tests/integration_tests/tasks/test_remove_app_and_related_data_task.py @@ -2,10 +2,10 @@ import uuid from unittest.mock import patch import pytest +from graphon.variables.segments import StringSegment from sqlalchemy import delete from core.db.session_factory import session_factory -from dify_graph.variables.segments import StringSegment from extensions.storage.storage_type import StorageType from models import Tenant from models.enums import CreatorUserRole @@ -192,7 +192,8 @@ class TestDeleteDraftVariablesWithOffloadIntegration: @pytest.fixture def setup_offload_test_data(self, app_and_tenant): tenant, app = app_and_tenant - from dify_graph.variables.types import SegmentType + from graphon.variables.types import SegmentType + from libs.datetime_utils import naive_utc_now with session_factory.create_session() as session: @@ -423,7 +424,8 @@ class TestDeleteDraftVariablesSessionCommit: @pytest.fixture def setup_offload_test_data(self, app_and_tenant): """Create test data with offload files for session commit tests.""" - from dify_graph.variables.types import SegmentType + from graphon.variables.types import SegmentType + from libs.datetime_utils import naive_utc_now tenant, app = app_and_tenant diff --git a/api/tests/integration_tests/workflow/nodes/__mock/model.py b/api/tests/integration_tests/workflow/nodes/__mock/model.py index 5b0f86fed1..c0143faa85 100644 --- a/api/tests/integration_tests/workflow/nodes/__mock/model.py +++ b/api/tests/integration_tests/workflow/nodes/__mock/model.py @@ -1,11 +1,12 @@ from unittest.mock import MagicMock +from 
graphon.model_runtime.entities.model_entities import ModelType + from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle from core.entities.provider_entities import CustomConfiguration, CustomProviderConfiguration, SystemConfiguration from core.model_manager import ModelInstance -from dify_graph.model_runtime.entities.model_entities import ModelType -from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory +from core.plugin.impl.model_runtime_factory import create_plugin_model_provider_factory from models.provider import ProviderType @@ -15,7 +16,7 @@ def get_mocked_fetch_model_config( mode: str, credentials: dict, ): - model_provider_factory = ModelProviderFactory(tenant_id="9d2074fc-6f86-45a9-b09d-6ecc63b9056b") + model_provider_factory = create_plugin_model_provider_factory(tenant_id="9d2074fc-6f86-45a9-b09d-6ecc63b9056b") model_type_instance = model_provider_factory.get_model_type_instance(provider, ModelType.LLM) provider_model_bundle = ProviderModelBundle( configuration=ProviderConfiguration( diff --git a/api/tests/integration_tests/workflow/nodes/test_code.py b/api/tests/integration_tests/workflow/nodes/test_code.py index e3a2b6b866..ce0c8bf8ca 100644 --- a/api/tests/integration_tests/workflow/nodes/test_code.py +++ b/api/tests/integration_tests/workflow/nodes/test_code.py @@ -2,17 +2,17 @@ import time import uuid import pytest +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.graph import Graph +from graphon.node_events import NodeRunResult +from graphon.nodes.code.code_node import CodeNode +from graphon.nodes.code.limits import CodeNodeLimits +from graphon.runtime import GraphRuntimeState, VariablePool from configs import dify_config from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.workflow.node_factory import DifyNodeFactory -from dify_graph.enums 
import WorkflowNodeExecutionStatus -from dify_graph.graph import Graph -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.code.code_node import CodeNode -from dify_graph.nodes.code.limits import CodeNodeLimits -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from core.workflow.system_variables import build_system_variables from tests.integration_tests.workflow.nodes.__mock.code_executor import setup_code_executor_mock from tests.workflow_test_utils import build_test_graph_init_params @@ -44,7 +44,7 @@ def init_code_node(code_config: dict): # construct variable pool variable_pool = VariablePool( - system_variables=SystemVariable(user_id="aaa", files=[]), + system_variables=build_system_variables(user_id="aaa", files=[]), user_inputs={}, environment_variables=[], conversation_variables=[], @@ -172,7 +172,7 @@ def test_execute_code_output_validator(setup_code_executor_mock): result = node._run() assert isinstance(result, NodeRunResult) assert result.status == WorkflowNodeExecutionStatus.FAILED - assert result.error == "Output result must be a string, got int instead" + assert result.error == "Output result must be a string, got int instead." 
def test_execute_code_output_validator_depth(): diff --git a/api/tests/integration_tests/workflow/nodes/test_http.py b/api/tests/integration_tests/workflow/nodes/test_http.py index f885f69e55..ce18486faf 100644 --- a/api/tests/integration_tests/workflow/nodes/test_http.py +++ b/api/tests/integration_tests/workflow/nodes/test_http.py @@ -3,18 +3,19 @@ import uuid from urllib.parse import urlencode import pytest +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.file.file_manager import file_manager +from graphon.graph import Graph +from graphon.nodes.http_request import HttpRequestNode, HttpRequestNodeConfig +from graphon.runtime import GraphRuntimeState, VariablePool from configs import dify_config from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.helper.ssrf_proxy import ssrf_proxy from core.tools.tool_file_manager import ToolFileManager from core.workflow.node_factory import DifyNodeFactory -from dify_graph.enums import WorkflowNodeExecutionStatus -from dify_graph.file.file_manager import file_manager -from dify_graph.graph import Graph -from dify_graph.nodes.http_request import HttpRequestNode, HttpRequestNodeConfig -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from core.workflow.node_runtime import DifyFileReferenceFactory +from core.workflow.system_variables import build_system_variables from tests.integration_tests.workflow.nodes.__mock.http import setup_http_mock from tests.workflow_test_utils import build_test_graph_init_params @@ -54,7 +55,7 @@ def init_http_node(config: dict): # construct variable pool variable_pool = VariablePool( - system_variables=SystemVariable(user_id="aaa", files=[]), + system_variables=build_system_variables(user_id="aaa", files=[]), user_inputs={}, environment_variables=[], conversation_variables=[], @@ -81,6 +82,7 @@ def init_http_node(config: dict): http_client=ssrf_proxy, 
tool_file_manager_factory=ToolFileManager, file_manager=file_manager, + file_reference_factory=DifyFileReferenceFactory(init_params.run_context), ) return node @@ -189,20 +191,21 @@ def test_custom_authorization_header(setup_http_mock): @pytest.mark.parametrize("setup_http_mock", [["none"]], indirect=True) def test_custom_auth_with_empty_api_key_raises_error(setup_http_mock): """Test: In custom authentication mode, when the api_key is empty, AuthorizationConfigError should be raised.""" - from dify_graph.enums import BuiltinNodeTypes - from dify_graph.nodes.http_request.entities import ( + from graphon.enums import BuiltinNodeTypes + from graphon.nodes.http_request.entities import ( HttpRequestNodeAuthorization, HttpRequestNodeData, HttpRequestNodeTimeout, ) - from dify_graph.nodes.http_request.exc import AuthorizationConfigError - from dify_graph.nodes.http_request.executor import Executor - from dify_graph.runtime import VariablePool - from dify_graph.system_variable import SystemVariable + from graphon.nodes.http_request.exc import AuthorizationConfigError + from graphon.nodes.http_request.executor import Executor + from graphon.runtime import VariablePool + + from core.workflow.system_variables import build_system_variables # Create variable pool variable_pool = VariablePool( - system_variables=SystemVariable(user_id="test", files=[]), + system_variables=build_system_variables(user_id="test", files=[]), user_inputs={}, environment_variables=[], conversation_variables=[], @@ -700,7 +703,7 @@ def test_nested_object_variable_selector(setup_http_mock): # Create independent variable pool for this test only variable_pool = VariablePool( - system_variables=SystemVariable(user_id="aaa", files=[]), + system_variables=build_system_variables(user_id="aaa", files=[]), user_inputs={}, environment_variables=[], conversation_variables=[], @@ -728,6 +731,7 @@ def test_nested_object_variable_selector(setup_http_mock): http_client=ssrf_proxy, 
tool_file_manager_factory=ToolFileManager, file_manager=file_manager, + file_reference_factory=DifyFileReferenceFactory(init_params.run_context), ) result = node._run() diff --git a/api/tests/integration_tests/workflow/nodes/test_llm.py b/api/tests/integration_tests/workflow/nodes/test_llm.py index d628348f1e..f0f3fcead1 100644 --- a/api/tests/integration_tests/workflow/nodes/test_llm.py +++ b/api/tests/integration_tests/workflow/nodes/test_llm.py @@ -4,16 +4,19 @@ import uuid from collections.abc import Generator from unittest.mock import MagicMock, patch +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.node_events import StreamCompletedEvent +from graphon.nodes.llm.file_saver import LLMFileSaver +from graphon.nodes.llm.node import LLMNode +from graphon.nodes.llm.protocols import CredentialsProvider, ModelFactory +from graphon.nodes.llm.runtime_protocols import PromptMessageSerializerProtocol +from graphon.nodes.protocols import HttpClientProtocol +from graphon.runtime import GraphRuntimeState, VariablePool + from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.llm_generator.output_parser.structured_output import _parse_structured_output from core.model_manager import ModelInstance -from dify_graph.enums import WorkflowNodeExecutionStatus -from dify_graph.node_events import StreamCompletedEvent -from dify_graph.nodes.llm.node import LLMNode -from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer -from dify_graph.nodes.protocols import HttpClientProtocol -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from core.workflow.system_variables import build_system_variables from extensions.ext_database import db from tests.workflow_test_utils import build_test_graph_init_params @@ -51,7 +54,7 @@ def init_llm_node(config: dict) -> LLMNode: # construct variable pool variable_pool = VariablePool( - 
system_variables=SystemVariable( + system_variables=build_system_variables( user_id="aaa", app_id=app_id, workflow_id=workflow_id, @@ -66,6 +69,11 @@ def init_llm_node(config: dict) -> LLMNode: variable_pool.add(["abc", "output"], "sunny") graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) + prompt_message_serializer = MagicMock(spec=PromptMessageSerializerProtocol) + prompt_message_serializer.serialize.side_effect = lambda *, model_mode, prompt_messages: [ + message.model_dump(mode="json") for message in prompt_messages + ] + llm_file_saver = MagicMock(spec=LLMFileSaver) node = LLMNode( id=str(uuid.uuid4()), @@ -75,7 +83,8 @@ def init_llm_node(config: dict) -> LLMNode: credentials_provider=MagicMock(spec=CredentialsProvider), model_factory=MagicMock(spec=ModelFactory), model_instance=MagicMock(spec=ModelInstance), - template_renderer=MagicMock(spec=TemplateRenderer), + llm_file_saver=llm_file_saver, + prompt_message_serializer=prompt_message_serializer, http_client=MagicMock(spec=HttpClientProtocol), ) @@ -115,8 +124,8 @@ def test_execute_llm(): from decimal import Decimal from unittest.mock import MagicMock - from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMUsage - from dify_graph.model_runtime.entities.message_entities import AssistantPromptMessage + from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage + from graphon.model_runtime.entities.message_entities import AssistantPromptMessage # Create mock model instance mock_model_instance = MagicMock(spec=ModelInstance) @@ -159,8 +168,8 @@ def test_execute_llm(): return mock_model_instance # Mock fetch_prompt_messages to avoid database calls - def mock_fetch_prompt_messages_1(*_args, **_kwargs): - from dify_graph.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage + def mock_fetch_prompt_messages_1(**_kwargs): + from graphon.model_runtime.entities.message_entities import 
SystemPromptMessage, UserPromptMessage return [ SystemPromptMessage(content="you are a helpful assistant. today's weather is sunny."), @@ -231,8 +240,8 @@ def test_execute_llm_with_jinja2(): from decimal import Decimal from unittest.mock import MagicMock - from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMUsage - from dify_graph.model_runtime.entities.message_entities import AssistantPromptMessage + from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage + from graphon.model_runtime.entities.message_entities import AssistantPromptMessage # Create mock model instance mock_model_instance = MagicMock(spec=ModelInstance) @@ -276,7 +285,7 @@ def test_execute_llm_with_jinja2(): # Mock fetch_prompt_messages to avoid database calls def mock_fetch_prompt_messages_2(**_kwargs): - from dify_graph.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage + from graphon.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage return [ SystemPromptMessage(content="you are a helpful assistant. 
today's weather is sunny."), diff --git a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py index 62d9af0196..3bf44df349 100644 --- a/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py +++ b/api/tests/integration_tests/workflow/nodes/test_parameter_extractor.py @@ -3,14 +3,16 @@ import time import uuid from unittest.mock import MagicMock +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.model_runtime.entities import AssistantPromptMessage, UserPromptMessage +from graphon.nodes.llm.protocols import CredentialsProvider, ModelFactory +from graphon.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode +from graphon.runtime import GraphRuntimeState, VariablePool + from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.model_manager import ModelInstance -from dify_graph.enums import WorkflowNodeExecutionStatus -from dify_graph.model_runtime.entities import AssistantPromptMessage, UserPromptMessage -from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory -from dify_graph.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from core.workflow.node_runtime import DifyPromptMessageSerializer +from core.workflow.system_variables import build_system_variables from extensions.ext_database import db from tests.integration_tests.workflow.nodes.__mock.model import get_mocked_fetch_model_instance from tests.workflow_test_utils import build_test_graph_init_params @@ -56,7 +58,7 @@ def init_parameter_extractor_node(config: dict, memory=None): # construct variable pool variable_pool = VariablePool( - system_variables=SystemVariable( + system_variables=build_system_variables( user_id="aaa", files=[], query="what's the weather in SF", 
conversation_id="abababa" ), user_inputs={}, @@ -77,6 +79,7 @@ def init_parameter_extractor_node(config: dict, memory=None): model_factory=MagicMock(spec=ModelFactory), model_instance=MagicMock(spec=ModelInstance), memory=memory, + prompt_message_serializer=DifyPromptMessageSerializer(), ) return node diff --git a/api/tests/integration_tests/workflow/nodes/test_template_transform.py b/api/tests/integration_tests/workflow/nodes/test_template_transform.py index 7bb4f905c3..2d728569be 100644 --- a/api/tests/integration_tests/workflow/nodes/test_template_transform.py +++ b/api/tests/integration_tests/workflow/nodes/test_template_transform.py @@ -1,14 +1,15 @@ import time import uuid +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.graph import Graph +from graphon.nodes.template_transform.template_transform_node import TemplateTransformNode +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.template_rendering import TemplateRenderError + from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.workflow.node_factory import DifyNodeFactory -from dify_graph.enums import WorkflowNodeExecutionStatus -from dify_graph.graph import Graph -from dify_graph.nodes.template_transform.template_renderer import TemplateRenderError -from dify_graph.nodes.template_transform.template_transform_node import TemplateTransformNode -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from core.workflow.system_variables import build_system_variables from tests.workflow_test_utils import build_test_graph_init_params @@ -66,7 +67,7 @@ def test_execute_template_transform(): # construct variable pool variable_pool = VariablePool( - system_variables=SystemVariable(user_id="aaa", files=[]), + system_variables=build_system_variables(user_id="aaa", files=[]), user_inputs={}, environment_variables=[], conversation_variables=[], @@ -90,7 +91,7 @@ def 
test_execute_template_transform(): config=config, graph_init_params=init_params, graph_runtime_state=graph_runtime_state, - template_renderer=_SimpleJinja2Renderer(), + jinja2_template_renderer=_SimpleJinja2Renderer(), ) # execute node diff --git a/api/tests/integration_tests/workflow/nodes/test_tool.py b/api/tests/integration_tests/workflow/nodes/test_tool.py index 818ae46625..750ced7075 100644 --- a/api/tests/integration_tests/workflow/nodes/test_tool.py +++ b/api/tests/integration_tests/workflow/nodes/test_tool.py @@ -2,16 +2,18 @@ import time import uuid from unittest.mock import MagicMock, patch +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.graph import Graph +from graphon.node_events import StreamCompletedEvent +from graphon.nodes.protocols import ToolFileManagerProtocol +from graphon.nodes.tool.tool_node import ToolNode +from graphon.runtime import GraphRuntimeState, VariablePool + from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.tools.utils.configuration import ToolParameterConfigurationManager from core.workflow.node_factory import DifyNodeFactory -from dify_graph.enums import WorkflowNodeExecutionStatus -from dify_graph.graph import Graph -from dify_graph.node_events import StreamCompletedEvent -from dify_graph.nodes.protocols import ToolFileManagerProtocol -from dify_graph.nodes.tool.tool_node import ToolNode -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from core.workflow.node_runtime import DifyToolNodeRuntime +from core.workflow.system_variables import build_system_variables from tests.workflow_test_utils import build_test_graph_init_params @@ -40,7 +42,7 @@ def init_tool_node(config: dict): # construct variable pool variable_pool = VariablePool( - system_variables=SystemVariable(user_id="aaa", files=[]), + system_variables=build_system_variables(user_id="aaa", files=[]), user_inputs={}, environment_variables=[], 
conversation_variables=[], @@ -64,6 +66,7 @@ def init_tool_node(config: dict): graph_init_params=init_params, graph_runtime_state=graph_runtime_state, tool_file_manager_factory=tool_file_manager_factory, + runtime=DifyToolNodeRuntime(init_params.run_context), ) return node diff --git a/api/tests/test_containers_integration_tests/conftest.py b/api/tests/test_containers_integration_tests/conftest.py index ef0ca4232d..be8a1c6aab 100644 --- a/api/tests/test_containers_integration_tests/conftest.py +++ b/api/tests/test_containers_integration_tests/conftest.py @@ -32,6 +32,10 @@ from extensions.ext_database import db # Configure logging for test containers logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s") logger = logging.getLogger(__name__) +_TEST_SANDBOX_IMAGE = os.getenv("TEST_SANDBOX_IMAGE", "langgenius/dify-sandbox:0.2.12") + +DEFAULT_SANDBOX_TEST_IMAGE = "langgenius/dify-sandbox:0.2.14" +SANDBOX_TEST_IMAGE_ENV = "DIFY_SANDBOX_TEST_IMAGE" class _CloserProtocol(Protocol): @@ -163,11 +167,11 @@ class DifyTestContainers: wait_for_logs(self.redis, "Ready to accept connections", timeout=30) logger.info("Redis container is ready and accepting connections") - # Start Dify Sandbox container for code execution environment - # Dify Sandbox provides a secure environment for executing user code - # Use pinned version 0.2.12 to match production docker-compose configuration + # Start Dify Sandbox container for code execution environment. + # Default to the production-pinned image while allowing local overrides for debugging. 
logger.info("Initializing Dify Sandbox container...") - self.dify_sandbox = DockerContainer(image="langgenius/dify-sandbox:0.2.12").with_network(self.network) + sandbox_image = os.getenv(SANDBOX_TEST_IMAGE_ENV, DEFAULT_SANDBOX_TEST_IMAGE) + self.dify_sandbox = DockerContainer(image=sandbox_image).with_network(self.network) self.dify_sandbox.with_exposed_ports(8194) self.dify_sandbox.env = { "API_KEY": "test_api_key", @@ -177,7 +181,12 @@ class DifyTestContainers: sandbox_port = self.dify_sandbox.get_exposed_port(8194) os.environ["CODE_EXECUTION_ENDPOINT"] = f"http://{sandbox_host}:{sandbox_port}" os.environ["CODE_EXECUTION_API_KEY"] = "test_api_key" - logger.info("Dify Sandbox container started successfully - Host: %s, Port: %s", sandbox_host, sandbox_port) + logger.info( + "Dify Sandbox container started successfully - Image: %s Host: %s, Port: %s", + sandbox_image, + sandbox_host, + sandbox_port, + ) # Wait for Dify Sandbox to be ready logger.info("Waiting for Dify Sandbox to be ready to accept connections...") @@ -187,7 +196,7 @@ class DifyTestContainers: # Start Dify Plugin Daemon container for plugin management # Dify Plugin Daemon provides plugin lifecycle management and execution logger.info("Initializing Dify Plugin Daemon container...") - self.dify_plugin_daemon = DockerContainer(image="langgenius/dify-plugin-daemon:0.5.4-local").with_network( + self.dify_plugin_daemon = DockerContainer(image="langgenius/dify-plugin-daemon:0.5.3-local").with_network( self.network ) self.dify_plugin_daemon.with_exposed_ports(5002) diff --git a/api/tests/test_containers_integration_tests/controllers/console/app/test_chat_conversation_status_count_api.py b/api/tests/test_containers_integration_tests/controllers/console/app/test_chat_conversation_status_count_api.py index 4f606dccb8..5cc458fe2e 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/app/test_chat_conversation_status_count_api.py +++ 
b/api/tests/test_containers_integration_tests/controllers/console/app/test_chat_conversation_status_count_api.py @@ -4,11 +4,11 @@ import json import uuid from flask.testing import FlaskClient +from graphon.enums import WorkflowExecutionStatus from sqlalchemy.orm import Session from configs import dify_config from constants import HEADER_NAME_CSRF_TOKEN -from dify_graph.enums import WorkflowExecutionStatus from libs.datetime_utils import naive_utc_now from libs.token import _real_cookie_name, generate_csrf_token from models import Account, DifySetup, Tenant, TenantAccountJoin diff --git a/api/tests/test_containers_integration_tests/controllers/console/app/test_workflow_draft_variable.py b/api/tests/test_containers_integration_tests/controllers/console/app/test_workflow_draft_variable.py index f037ad77c0..8ddf867370 100644 --- a/api/tests/test_containers_integration_tests/controllers/console/app/test_workflow_draft_variable.py +++ b/api/tests/test_containers_integration_tests/controllers/console/app/test_workflow_draft_variable.py @@ -3,11 +3,11 @@ import uuid from flask.testing import FlaskClient +from graphon.variables.segments import StringSegment from sqlalchemy import select from sqlalchemy.orm import Session -from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID -from dify_graph.variables.segments import StringSegment +from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID from factories.variable_factory import segment_to_variable from models import Workflow from models.model import AppMode diff --git a/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py b/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py index 96fb7ea293..2b4c1b59ab 100644 --- a/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py +++ 
b/api/tests/test_containers_integration_tests/core/app/layers/test_pause_state_persist_layer.py @@ -22,6 +22,13 @@ import uuid from time import time import pytest +from graphon.entities.pause_reason import SchedulingPause +from graphon.enums import WorkflowExecutionStatus +from graphon.graph_engine.entities.commands import GraphEngineCommand +from graphon.graph_engine.layers.base import GraphEngineLayerNotInitializedError +from graphon.graph_events import GraphRunPausedEvent +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper, VariablePool from sqlalchemy import Engine, delete, select from sqlalchemy.orm import Session @@ -31,16 +38,7 @@ from core.app.layers.pause_state_persist_layer import ( PauseStatePersistenceLayer, WorkflowResumptionContext, ) -from dify_graph.entities.pause_reason import SchedulingPause -from dify_graph.enums import WorkflowExecutionStatus -from dify_graph.graph_engine.entities.commands import GraphEngineCommand -from dify_graph.graph_engine.layers.base import GraphEngineLayerNotInitializedError -from dify_graph.graph_events.graph import GraphRunPausedEvent -from dify_graph.model_runtime.entities.llm_entities import LLMUsage -from dify_graph.runtime.graph_runtime_state import GraphRuntimeState -from dify_graph.runtime.graph_runtime_state_protocol import ReadOnlyGraphRuntimeState -from dify_graph.runtime.read_only_wrappers import ReadOnlyGraphRuntimeStateWrapper -from dify_graph.runtime.variable_pool import SystemVariable, VariablePool +from core.workflow.system_variables import build_system_variables from extensions.ext_storage import storage from libs.datetime_utils import naive_utc_now from models import Account @@ -212,7 +210,7 @@ class TestPauseStatePersistenceLayerTestContainers: execution_id = workflow_run_id or getattr(self, "test_workflow_run_id", None) or str(uuid.uuid4()) # Create variable pool - variable_pool = 
VariablePool(system_variables=SystemVariable(workflow_execution_id=execution_id)) + variable_pool = VariablePool(system_variables=build_system_variables(workflow_execution_id=execution_id)) if variables: for (node_id, var_key), value in variables.items(): variable_pool.add([node_id, var_key], value) @@ -544,7 +542,7 @@ class TestPauseStatePersistenceLayerTestContainers: layer.initialize(graph_runtime_state, command_channel) # Import other event types - from dify_graph.graph_events.graph import ( + from graphon.graph_events import ( GraphRunFailedEvent, GraphRunStartedEvent, GraphRunSucceededEvent, diff --git a/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py b/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py index 9d0fad4b12..13caad799e 100644 --- a/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py +++ b/api/tests/test_containers_integration_tests/core/repositories/test_human_input_form_repository_impl.py @@ -4,23 +4,20 @@ from __future__ import annotations from uuid import uuid4 +from graphon.nodes.human_input.entities import FormDefinition, HumanInputNodeData, UserAction from sqlalchemy import Engine, select from sqlalchemy.orm import Session -from core.repositories.human_input_repository import HumanInputFormRepositoryImpl -from dify_graph.nodes.human_input.entities import ( +from core.repositories.human_input_repository import FormCreateParams, HumanInputFormRepositoryImpl +from core.workflow.human_input_compat import ( DeliveryChannelConfig, EmailDeliveryConfig, EmailDeliveryMethod, EmailRecipients, ExternalRecipient, - FormDefinition, - HumanInputNodeData, MemberRecipient, - UserAction, WebAppDeliveryMethod, ) -from dify_graph.repositories.human_input_form_repository import FormCreateParams from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.human_input import 
( EmailExternalRecipientPayload, @@ -68,7 +65,6 @@ def _build_form_params(delivery_methods: list[DeliveryChannelConfig]) -> FormCre user_actions=[UserAction(id="approve", title="Approve")], ) return FormCreateParams( - app_id=str(uuid4()), workflow_execution_id=str(uuid4()), node_id="human-input-node", form_config=form_config, @@ -84,7 +80,7 @@ def _build_email_delivery( ) -> EmailDeliveryMethod: return EmailDeliveryMethod( config=EmailDeliveryConfig( - recipients=EmailRecipients(whole_workspace=whole_workspace, items=recipients), + recipients=EmailRecipients(include_bound_group=whole_workspace, items=recipients), subject="Approval Needed", body="Please review", ) @@ -100,7 +96,7 @@ class TestHumanInputFormRepositoryImplWithContainers: member_emails=["member1@example.com", "member2@example.com"], ) - repository = HumanInputFormRepositoryImpl(tenant_id=tenant.id) + repository = HumanInputFormRepositoryImpl(tenant_id=tenant.id, app_id=str(uuid4())) params = _build_form_params( delivery_methods=[_build_email_delivery(whole_workspace=True, recipients=[])], ) @@ -129,13 +125,13 @@ class TestHumanInputFormRepositoryImplWithContainers: member_emails=["primary@example.com", "secondary@example.com"], ) - repository = HumanInputFormRepositoryImpl(tenant_id=tenant.id) + repository = HumanInputFormRepositoryImpl(tenant_id=tenant.id, app_id=str(uuid4())) params = _build_form_params( delivery_methods=[ _build_email_delivery( whole_workspace=False, recipients=[ - MemberRecipient(user_id=members[0].id), + MemberRecipient(reference_id=members[0].id), ExternalRecipient(email="external@example.com"), ], ) @@ -173,10 +169,9 @@ class TestHumanInputFormRepositoryImplWithContainers: member_emails=["prefill@example.com"], ) - repository = HumanInputFormRepositoryImpl(tenant_id=tenant.id) + repository = HumanInputFormRepositoryImpl(tenant_id=tenant.id, app_id=str(uuid4())) resolved_values = {"greeting": "Hello!"} params = FormCreateParams( - app_id=str(uuid4()), 
workflow_execution_id=str(uuid4()), node_id="human-input-node", form_config=HumanInputNodeData( @@ -210,9 +205,8 @@ class TestHumanInputFormRepositoryImplWithContainers: member_emails=["ui@example.com"], ) - repository = HumanInputFormRepositoryImpl(tenant_id=tenant.id) + repository = HumanInputFormRepositoryImpl(tenant_id=tenant.id, app_id=str(uuid4())) params = FormCreateParams( - app_id=str(uuid4()), workflow_execution_id=str(uuid4()), node_id="human-input-node", form_config=HumanInputNodeData( diff --git a/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py b/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py index 9733735df3..0a9b476afc 100644 --- a/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py +++ b/api/tests/test_containers_integration_tests/core/workflow/test_human_input_resume_node_execution.py @@ -4,28 +4,29 @@ from datetime import timedelta from unittest.mock import MagicMock import pytest +from graphon.enums import WorkflowType +from graphon.graph import Graph +from graphon.graph_engine import GraphEngine +from graphon.graph_engine.command_channels import InMemoryChannel +from graphon.nodes.end.end_node import EndNode +from graphon.nodes.end.entities import EndNodeData +from graphon.nodes.human_input.entities import HumanInputNodeData, UserAction +from graphon.nodes.human_input.enums import HumanInputFormStatus +from graphon.nodes.human_input.human_input_node import HumanInputNode +from graphon.nodes.start.entities import StartNodeData +from graphon.nodes.start.start_node import StartNode +from graphon.runtime import GraphRuntimeState, VariablePool from sqlalchemy import delete, select from sqlalchemy.orm import Session from core.app.app_config.entities import WorkflowUIBasedAppConfig from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity from core.app.workflow.layers 
import PersistenceWorkflowInfo, WorkflowPersistenceLayer +from core.repositories.human_input_repository import HumanInputFormEntity, HumanInputFormRepository from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository -from dify_graph.enums import WorkflowType -from dify_graph.graph import Graph -from dify_graph.graph_engine.command_channels.in_memory_channel import InMemoryChannel -from dify_graph.graph_engine.graph_engine import GraphEngine -from dify_graph.nodes.end.end_node import EndNode -from dify_graph.nodes.end.entities import EndNodeData -from dify_graph.nodes.human_input.entities import HumanInputNodeData, UserAction -from dify_graph.nodes.human_input.enums import HumanInputFormStatus -from dify_graph.nodes.human_input.human_input_node import HumanInputNode -from dify_graph.nodes.start.entities import StartNodeData -from dify_graph.nodes.start.start_node import StartNode -from dify_graph.repositories.human_input_form_repository import HumanInputFormEntity, HumanInputFormRepository -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from core.workflow.node_runtime import DifyHumanInputNodeRuntime +from core.workflow.system_variables import build_system_variables from libs.datetime_utils import naive_utc_now from models import Account from models.account import Tenant, TenantAccountJoin, TenantAccountRole @@ -39,7 +40,7 @@ def _mock_form_repository_without_submission() -> HumanInputFormRepository: repo = MagicMock(spec=HumanInputFormRepository) form_entity = MagicMock(spec=HumanInputFormEntity) form_entity.id = "test-form-id" - form_entity.web_app_token = "test-form-token" + form_entity.submission_token = "test-form-token" form_entity.recipients = [] form_entity.rendered_content = "rendered" form_entity.submitted = False @@ 
-52,7 +53,7 @@ def _mock_form_repository_with_submission(action_id: str) -> HumanInputFormRepos repo = MagicMock(spec=HumanInputFormRepository) form_entity = MagicMock(spec=HumanInputFormEntity) form_entity.id = "test-form-id" - form_entity.web_app_token = "test-form-token" + form_entity.submission_token = "test-form-token" form_entity.recipients = [] form_entity.rendered_content = "rendered" form_entity.submitted = True @@ -66,7 +67,7 @@ def _mock_form_repository_with_submission(action_id: str) -> HumanInputFormRepos def _build_runtime_state(workflow_execution_id: str, app_id: str, workflow_id: str, user_id: str) -> GraphRuntimeState: variable_pool = VariablePool( - system_variables=SystemVariable( + system_variables=build_system_variables( workflow_execution_id=workflow_execution_id, app_id=app_id, workflow_id=workflow_id, @@ -120,6 +121,7 @@ def _build_graph( graph_init_params=params, graph_runtime_state=runtime_state, form_repository=form_repository, + runtime=DifyHumanInputNodeRuntime(params.run_context), ) end_data = EndNodeData( diff --git a/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py b/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py index 8e70fc0bb0..cc72dc1cf3 100644 --- a/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py +++ b/api/tests/test_containers_integration_tests/factories/test_storage_key_loader.py @@ -4,9 +4,10 @@ from unittest.mock import patch from uuid import uuid4 import pytest +from graphon.file import File, FileTransferMethod, FileType from sqlalchemy.orm import Session -from dify_graph.file import File, FileTransferMethod, FileType +from core.app.file_access import DatabaseFileAccessController from extensions.ext_database import db from extensions.storage.storage_type import StorageType from factories.file_factory import StorageKeyLoader @@ -35,7 +36,11 @@ class TestStorageKeyLoader(unittest.TestCase): self.test_tool_files = [] # Create 
StorageKeyLoader instance - self.loader = StorageKeyLoader(self.session, self.tenant_id) + self.loader = StorageKeyLoader( + self.session, + self.tenant_id, + access_controller=DatabaseFileAccessController(), + ) def tearDown(self): """Clean up test data after each test method.""" @@ -193,19 +198,16 @@ class TestStorageKeyLoader(unittest.TestCase): # Should not raise any exceptions self.loader.load_storage_keys([]) - def test_load_storage_keys_tenant_mismatch(self): - """Test tenant_id validation.""" - # Create file with different tenant_id + def test_load_storage_keys_ignores_legacy_file_tenant_id(self): + """Legacy file tenant_id should not override the loader tenant scope.""" upload_file = self._create_upload_file() file = self._create_file( related_id=upload_file.id, transfer_method=FileTransferMethod.LOCAL_FILE, tenant_id=str(uuid4()) ) - # Should raise ValueError for tenant mismatch - with pytest.raises(ValueError) as context: - self.loader.load_storage_keys([file]) + self.loader.load_storage_keys([file]) - assert "invalid file, expected tenant_id" in str(context.value) + assert file._storage_key == upload_file.key def test_load_storage_keys_missing_file_id(self): """Test with None file.related_id.""" @@ -314,7 +316,7 @@ class TestStorageKeyLoader(unittest.TestCase): with pytest.raises(ValueError) as context: self.loader.load_storage_keys([file_other]) - assert "invalid file, expected tenant_id" in str(context.value) + assert "Upload file not found for id:" in str(context.value) # Current tenant's file should still work self.loader.load_storage_keys([file_current]) @@ -338,7 +340,7 @@ class TestStorageKeyLoader(unittest.TestCase): with pytest.raises(ValueError) as context: self.loader.load_storage_keys([file_current, file_other]) - assert "invalid file, expected tenant_id" in str(context.value) + assert "Upload file not found for id:" in str(context.value) def test_load_storage_keys_duplicate_file_ids(self): """Test handling of duplicate file IDs in the 
batch.""" @@ -365,6 +367,10 @@ class TestStorageKeyLoader(unittest.TestCase): # Create loader with different session (same underlying connection) with Session(bind=db.engine) as other_session: - other_loader = StorageKeyLoader(other_session, self.tenant_id) + other_loader = StorageKeyLoader( + other_session, + self.tenant_id, + access_controller=DatabaseFileAccessController(), + ) with pytest.raises(ValueError): other_loader.load_storage_keys([file]) diff --git a/api/tests/test_containers_integration_tests/helpers/execution_extra_content.py b/api/tests/test_containers_integration_tests/helpers/execution_extra_content.py index fb8d1808f9..b745aed141 100644 --- a/api/tests/test_containers_integration_tests/helpers/execution_extra_content.py +++ b/api/tests/test_containers_integration_tests/helpers/execution_extra_content.py @@ -1,11 +1,13 @@ from __future__ import annotations from dataclasses import dataclass -from datetime import datetime, timedelta +from datetime import timedelta from decimal import Decimal from uuid import uuid4 -from dify_graph.nodes.human_input.entities import FormDefinition, UserAction +from graphon.nodes.human_input.entities import FormDefinition, UserAction + +from libs.datetime_utils import naive_utc_now from models.account import Account, Tenant, TenantAccountJoin from models.enums import ConversationFromSource, InvokeFrom from models.execution_extra_content import HumanInputContent @@ -117,7 +119,7 @@ def create_human_input_message_fixture(db_session) -> HumanInputMessageFixture: inputs=[], user_actions=[UserAction(id=action_id, title=action_text)], rendered_content="Rendered block", - expiration_time=datetime.utcnow() + timedelta(days=1), + expiration_time=naive_utc_now() + timedelta(days=1), node_title=node_title, display_in_ui=True, ) @@ -129,7 +131,7 @@ def create_human_input_message_fixture(db_session) -> HumanInputMessageFixture: form_definition=form_definition.model_dump_json(), rendered_content="Rendered block", 
status=HumanInputFormStatus.SUBMITTED, - expiration_time=datetime.utcnow() + timedelta(days=1), + expiration_time=naive_utc_now() + timedelta(days=1), selected_action_id=action_id, ) db_session.add(form) diff --git a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_node_execution_repository.py b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_node_execution_repository.py index 458862b0ec..a68b3a08c7 100644 --- a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_node_execution_repository.py +++ b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_node_execution_repository.py @@ -5,10 +5,10 @@ from __future__ import annotations from datetime import timedelta from uuid import uuid4 +from graphon.enums import WorkflowNodeExecutionStatus from sqlalchemy import Engine, delete from sqlalchemy.orm import Session, sessionmaker -from dify_graph.enums import WorkflowNodeExecutionStatus from libs.datetime_utils import naive_utc_now from models.enums import CreatorUserRole from models.workflow import WorkflowNodeExecutionModel diff --git a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py index 49b370990a..d28cfda159 100644 --- a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py +++ b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_api_workflow_run_repository.py @@ -2,30 +2,27 @@ from __future__ import annotations -import secrets from dataclasses import dataclass, field from datetime import datetime, timedelta from unittest.mock import Mock from uuid import uuid4 import pytest +from graphon.entities import WorkflowExecution +from graphon.entities.pause_reason import HumanInputRequired, 
PauseReasonType +from graphon.enums import WorkflowExecutionStatus +from graphon.nodes.human_input.entities import FormDefinition, FormInput, UserAction +from graphon.nodes.human_input.enums import FormInputType, HumanInputFormStatus from sqlalchemy import Engine, delete, select from sqlalchemy.orm import Session, sessionmaker -from dify_graph.entities import WorkflowExecution -from dify_graph.entities.pause_reason import HumanInputRequired, PauseReasonType -from dify_graph.enums import WorkflowExecutionStatus -from dify_graph.nodes.human_input.entities import FormDefinition, FormInput, UserAction -from dify_graph.nodes.human_input.enums import DeliveryMethodType, FormInputType, HumanInputFormStatus from extensions.ext_storage import storage from libs.datetime_utils import naive_utc_now from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom from models.human_input import ( - BackstageRecipientPayload, HumanInputDelivery, HumanInputForm, HumanInputFormRecipient, - RecipientType, ) from models.workflow import WorkflowAppLog, WorkflowAppLogCreatedFrom, WorkflowPause, WorkflowPauseReason, WorkflowRun from repositories.entities.workflow_pause import WorkflowPauseEntity @@ -636,12 +633,12 @@ class TestPrivateWorkflowPauseEntity: class TestBuildHumanInputRequiredReason: """Integration tests for _build_human_input_required_reason using real DB models.""" - def test_prefers_backstage_token_when_available( + def test_builds_reason_from_form_definition( self, db_session_with_containers: Session, test_scope: _TestScope, ) -> None: - """Use backstage token when multiple recipient types may exist.""" + """Build the graph pause reason from the stored form definition.""" expiration_time = naive_utc_now() form_definition = FormDefinition( @@ -668,25 +665,6 @@ class TestBuildHumanInputRequiredReason: db_session_with_containers.add(form_model) db_session_with_containers.flush() - delivery = HumanInputDelivery( - form_id=form_model.id, - 
delivery_method_type=DeliveryMethodType.WEBAPP, - channel_payload="{}", - ) - db_session_with_containers.add(delivery) - db_session_with_containers.flush() - - access_token = secrets.token_urlsafe(8) - recipient = HumanInputFormRecipient( - form_id=form_model.id, - delivery_id=delivery.id, - recipient_type=RecipientType.BACKSTAGE, - recipient_payload=BackstageRecipientPayload().model_dump_json(), - access_token=access_token, - ) - db_session_with_containers.add(recipient) - db_session_with_containers.flush() - # Create a pause so the reason has a valid pause_id workflow_run = _create_workflow_run( db_session_with_containers, @@ -716,13 +694,12 @@ class TestBuildHumanInputRequiredReason: # Refresh to ensure we have DB-round-tripped objects db_session_with_containers.refresh(form_model) db_session_with_containers.refresh(reason_model) - db_session_with_containers.refresh(recipient) - reason = _build_human_input_required_reason(reason_model, form_model, [recipient]) + reason = _build_human_input_required_reason(reason_model, form_model) assert isinstance(reason, HumanInputRequired) - assert reason.form_token == access_token assert reason.node_title == "Ask Name" assert reason.form_content == "content" assert reason.inputs[0].output_variable_name == "name" assert reason.actions[0].id == "approve" + assert reason.resolved_default_values == {"name": "Alice"} diff --git a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_execution_extra_content_repository.py b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_execution_extra_content_repository.py index ed998c9ed0..7f44eb6ca3 100644 --- a/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_execution_extra_content_repository.py +++ b/api/tests/test_containers_integration_tests/repositories/test_sqlalchemy_execution_extra_content_repository.py @@ -7,16 +7,17 @@ from __future__ import annotations from collections.abc import Generator from dataclasses import 
dataclass -from datetime import datetime, timedelta +from datetime import timedelta from decimal import Decimal from uuid import uuid4 import pytest +from graphon.nodes.human_input.entities import FormDefinition, UserAction +from graphon.nodes.human_input.enums import HumanInputFormStatus from sqlalchemy import Engine, delete, select from sqlalchemy.orm import Session, sessionmaker -from dify_graph.nodes.human_input.entities import FormDefinition, UserAction -from dify_graph.nodes.human_input.enums import HumanInputFormStatus +from libs.datetime_utils import naive_utc_now from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.enums import ConversationFromSource, InvokeFrom from models.execution_extra_content import ExecutionExtraContent, HumanInputContent @@ -174,7 +175,7 @@ def _create_submitted_form( action_title: str = "Approve", node_title: str = "Approval", ) -> HumanInputForm: - expiration_time = datetime.utcnow() + timedelta(days=1) + expiration_time = naive_utc_now() + timedelta(days=1) form_definition = FormDefinition( form_content="content", inputs=[], @@ -207,7 +208,7 @@ def _create_waiting_form( workflow_run_id: str, default_values: dict | None = None, ) -> HumanInputForm: - expiration_time = datetime.utcnow() + timedelta(days=1) + expiration_time = naive_utc_now() + timedelta(days=1) form_definition = FormDefinition( form_content="content", inputs=[], @@ -270,7 +271,7 @@ def _create_recipient( def _create_delivery(session: Session, *, form_id: str) -> HumanInputDelivery: - from dify_graph.nodes.human_input.enums import DeliveryMethodType + from core.workflow.human_input_compat import DeliveryMethodType from models.human_input import ConsoleDeliveryPayload delivery = HumanInputDelivery( diff --git a/api/tests/test_containers_integration_tests/repositories/test_workflow_run_repository.py b/api/tests/test_containers_integration_tests/repositories/test_workflow_run_repository.py index 1568d5d65c..c5e9201ee3 100644 --- 
a/api/tests/test_containers_integration_tests/repositories/test_workflow_run_repository.py +++ b/api/tests/test_containers_integration_tests/repositories/test_workflow_run_repository.py @@ -7,12 +7,12 @@ from datetime import timedelta from uuid import uuid4 import pytest +from graphon.entities import WorkflowExecution +from graphon.enums import WorkflowExecutionStatus from sqlalchemy import Engine, delete from sqlalchemy import exc as sa_exc from sqlalchemy.orm import Session, sessionmaker -from dify_graph.entities import WorkflowExecution -from dify_graph.enums import WorkflowExecutionStatus from libs.datetime_utils import naive_utc_now from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom from models.workflow import WorkflowRun, WorkflowType diff --git a/api/dify_graph/model_runtime/__init__.py b/api/tests/test_containers_integration_tests/services/auth/__init__.py similarity index 100% rename from api/dify_graph/model_runtime/__init__.py rename to api/tests/test_containers_integration_tests/services/auth/__init__.py diff --git a/api/tests/test_containers_integration_tests/services/auth/test_api_key_auth_service.py b/api/tests/test_containers_integration_tests/services/auth/test_api_key_auth_service.py new file mode 100644 index 0000000000..177fb95ff3 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/auth/test_api_key_auth_service.py @@ -0,0 +1,289 @@ +from __future__ import annotations + +import json +from unittest.mock import Mock, patch +from uuid import uuid4 + +import pytest + +from models.source import DataSourceApiKeyAuthBinding +from services.auth.api_key_auth_service import ApiKeyAuthService + + +class TestApiKeyAuthService: + @pytest.fixture + def tenant_id(self) -> str: + return str(uuid4()) + + @pytest.fixture + def category(self) -> str: + return "search" + + @pytest.fixture + def provider(self) -> str: + return "google" + + @pytest.fixture + def mock_credentials(self) -> dict: + return {"auth_type": "api_key", 
"config": {"api_key": "test_secret_key_123"}} + + @pytest.fixture + def mock_args(self, category, provider, mock_credentials) -> dict: + return {"category": category, "provider": provider, "credentials": mock_credentials} + + def _create_binding(self, db_session, *, tenant_id, category, provider, credentials=None, disabled=False): + binding = DataSourceApiKeyAuthBinding( + tenant_id=tenant_id, + category=category, + provider=provider, + credentials=json.dumps(credentials, ensure_ascii=False) if credentials else None, + disabled=disabled, + ) + db_session.add(binding) + db_session.commit() + return binding + + def test_get_provider_auth_list_success( + self, flask_app_with_containers, db_session_with_containers, tenant_id, category, provider + ): + self._create_binding(db_session_with_containers, tenant_id=tenant_id, category=category, provider=provider) + db_session_with_containers.expire_all() + + result = ApiKeyAuthService.get_provider_auth_list(tenant_id) + + assert len(result) >= 1 + tenant_results = [r for r in result if r.tenant_id == tenant_id] + assert len(tenant_results) == 1 + assert tenant_results[0].provider == provider + + def test_get_provider_auth_list_empty(self, flask_app_with_containers, db_session_with_containers, tenant_id): + result = ApiKeyAuthService.get_provider_auth_list(tenant_id) + + tenant_results = [r for r in result if r.tenant_id == tenant_id] + assert tenant_results == [] + + def test_get_provider_auth_list_filters_disabled( + self, flask_app_with_containers, db_session_with_containers, tenant_id, category, provider + ): + self._create_binding( + db_session_with_containers, tenant_id=tenant_id, category=category, provider=provider, disabled=True + ) + db_session_with_containers.expire_all() + + result = ApiKeyAuthService.get_provider_auth_list(tenant_id) + + tenant_results = [r for r in result if r.tenant_id == tenant_id] + assert tenant_results == [] + + @patch("services.auth.api_key_auth_service.ApiKeyAuthFactory") + 
@patch("services.auth.api_key_auth_service.encrypter") + def test_create_provider_auth_success( + self, mock_encrypter, mock_factory, flask_app_with_containers, db_session_with_containers, tenant_id, mock_args + ): + mock_auth_instance = Mock() + mock_auth_instance.validate_credentials.return_value = True + mock_factory.return_value = mock_auth_instance + mock_encrypter.encrypt_token.return_value = "encrypted_test_key_123" + + ApiKeyAuthService.create_provider_auth(tenant_id, mock_args) + + mock_factory.assert_called_once() + mock_auth_instance.validate_credentials.assert_called_once() + mock_encrypter.encrypt_token.assert_called_once_with(tenant_id, "test_secret_key_123") + + db_session_with_containers.expire_all() + bindings = db_session_with_containers.query(DataSourceApiKeyAuthBinding).filter_by(tenant_id=tenant_id).all() + assert len(bindings) == 1 + + @patch("services.auth.api_key_auth_service.ApiKeyAuthFactory") + def test_create_provider_auth_validation_failed( + self, mock_factory, flask_app_with_containers, db_session_with_containers, tenant_id, mock_args + ): + mock_auth_instance = Mock() + mock_auth_instance.validate_credentials.return_value = False + mock_factory.return_value = mock_auth_instance + + ApiKeyAuthService.create_provider_auth(tenant_id, mock_args) + + db_session_with_containers.expire_all() + bindings = db_session_with_containers.query(DataSourceApiKeyAuthBinding).filter_by(tenant_id=tenant_id).all() + assert len(bindings) == 0 + + @patch("services.auth.api_key_auth_service.ApiKeyAuthFactory") + @patch("services.auth.api_key_auth_service.encrypter") + def test_create_provider_auth_encrypts_api_key( + self, mock_encrypter, mock_factory, flask_app_with_containers, db_session_with_containers, tenant_id, mock_args + ): + mock_auth_instance = Mock() + mock_auth_instance.validate_credentials.return_value = True + mock_factory.return_value = mock_auth_instance + mock_encrypter.encrypt_token.return_value = "encrypted_test_key_123" + + original_key 
= mock_args["credentials"]["config"]["api_key"] + + ApiKeyAuthService.create_provider_auth(tenant_id, mock_args) + + assert mock_args["credentials"]["config"]["api_key"] == "encrypted_test_key_123" + assert mock_args["credentials"]["config"]["api_key"] != original_key + mock_encrypter.encrypt_token.assert_called_once_with(tenant_id, original_key) + + def test_get_auth_credentials_success( + self, flask_app_with_containers, db_session_with_containers, tenant_id, category, provider, mock_credentials + ): + self._create_binding( + db_session_with_containers, + tenant_id=tenant_id, + category=category, + provider=provider, + credentials=mock_credentials, + ) + db_session_with_containers.expire_all() + + result = ApiKeyAuthService.get_auth_credentials(tenant_id, category, provider) + + assert result == mock_credentials + + def test_get_auth_credentials_not_found( + self, flask_app_with_containers, db_session_with_containers, tenant_id, category, provider + ): + result = ApiKeyAuthService.get_auth_credentials(tenant_id, category, provider) + + assert result is None + + def test_get_auth_credentials_json_parsing( + self, flask_app_with_containers, db_session_with_containers, tenant_id, category, provider + ): + special_credentials = {"auth_type": "api_key", "config": {"api_key": "key_with_中文_and_special_chars_!@#$%"}} + self._create_binding( + db_session_with_containers, + tenant_id=tenant_id, + category=category, + provider=provider, + credentials=special_credentials, + ) + db_session_with_containers.expire_all() + + result = ApiKeyAuthService.get_auth_credentials(tenant_id, category, provider) + + assert result == special_credentials + assert result["config"]["api_key"] == "key_with_中文_and_special_chars_!@#$%" + + def test_delete_provider_auth_success( + self, flask_app_with_containers, db_session_with_containers, tenant_id, category, provider + ): + binding = self._create_binding( + db_session_with_containers, tenant_id=tenant_id, category=category, provider=provider + 
) + binding_id = binding.id + db_session_with_containers.expire_all() + + ApiKeyAuthService.delete_provider_auth(tenant_id, binding_id) + + db_session_with_containers.expire_all() + remaining = db_session_with_containers.query(DataSourceApiKeyAuthBinding).filter_by(id=binding_id).first() + assert remaining is None + + def test_delete_provider_auth_not_found(self, flask_app_with_containers, db_session_with_containers, tenant_id): + # Should not raise when binding not found + ApiKeyAuthService.delete_provider_auth(tenant_id, str(uuid4())) + + def test_validate_api_key_auth_args_success(self, mock_args): + ApiKeyAuthService.validate_api_key_auth_args(mock_args) + + def test_validate_api_key_auth_args_missing_category(self, mock_args): + del mock_args["category"] + with pytest.raises(ValueError, match="category is required"): + ApiKeyAuthService.validate_api_key_auth_args(mock_args) + + def test_validate_api_key_auth_args_empty_category(self, mock_args): + mock_args["category"] = "" + with pytest.raises(ValueError, match="category is required"): + ApiKeyAuthService.validate_api_key_auth_args(mock_args) + + def test_validate_api_key_auth_args_missing_provider(self, mock_args): + del mock_args["provider"] + with pytest.raises(ValueError, match="provider is required"): + ApiKeyAuthService.validate_api_key_auth_args(mock_args) + + def test_validate_api_key_auth_args_empty_provider(self, mock_args): + mock_args["provider"] = "" + with pytest.raises(ValueError, match="provider is required"): + ApiKeyAuthService.validate_api_key_auth_args(mock_args) + + def test_validate_api_key_auth_args_missing_credentials(self, mock_args): + del mock_args["credentials"] + with pytest.raises(ValueError, match="credentials is required"): + ApiKeyAuthService.validate_api_key_auth_args(mock_args) + + def test_validate_api_key_auth_args_empty_credentials(self, mock_args): + mock_args["credentials"] = None + with pytest.raises(ValueError, match="credentials is required"): + 
ApiKeyAuthService.validate_api_key_auth_args(mock_args) + + def test_validate_api_key_auth_args_invalid_credentials_type(self, mock_args): + mock_args["credentials"] = "not_a_dict" + with pytest.raises(ValueError, match="credentials must be a dictionary"): + ApiKeyAuthService.validate_api_key_auth_args(mock_args) + + def test_validate_api_key_auth_args_missing_auth_type(self, mock_args): + del mock_args["credentials"]["auth_type"] + with pytest.raises(ValueError, match="auth_type is required"): + ApiKeyAuthService.validate_api_key_auth_args(mock_args) + + def test_validate_api_key_auth_args_empty_auth_type(self, mock_args): + mock_args["credentials"]["auth_type"] = "" + with pytest.raises(ValueError, match="auth_type is required"): + ApiKeyAuthService.validate_api_key_auth_args(mock_args) + + @pytest.mark.parametrize( + "malicious_input", + [ + "", + "'; DROP TABLE users; --", + "../../../etc/passwd", + "\\x00\\x00", + "A" * 10000, + ], + ) + def test_validate_api_key_auth_args_malicious_input(self, malicious_input, mock_args): + mock_args["category"] = malicious_input + ApiKeyAuthService.validate_api_key_auth_args(mock_args) + + @patch("services.auth.api_key_auth_service.ApiKeyAuthFactory") + @patch("services.auth.api_key_auth_service.encrypter") + def test_create_provider_auth_database_error_handling( + self, mock_encrypter, mock_factory, flask_app_with_containers, tenant_id, mock_args + ): + mock_auth_instance = Mock() + mock_auth_instance.validate_credentials.return_value = True + mock_factory.return_value = mock_auth_instance + mock_encrypter.encrypt_token.return_value = "encrypted_key" + + with patch("services.auth.api_key_auth_service.db.session") as mock_session: + mock_session.commit.side_effect = Exception("Database error") + with pytest.raises(Exception, match="Database error"): + ApiKeyAuthService.create_provider_auth(tenant_id, mock_args) + + @patch("services.auth.api_key_auth_service.ApiKeyAuthFactory") + def 
test_create_provider_auth_factory_exception(self, mock_factory, tenant_id, mock_args): + mock_factory.side_effect = Exception("Factory error") + with pytest.raises(Exception, match="Factory error"): + ApiKeyAuthService.create_provider_auth(tenant_id, mock_args) + + @patch("services.auth.api_key_auth_service.ApiKeyAuthFactory") + @patch("services.auth.api_key_auth_service.encrypter") + def test_create_provider_auth_encryption_exception(self, mock_encrypter, mock_factory, tenant_id, mock_args): + mock_auth_instance = Mock() + mock_auth_instance.validate_credentials.return_value = True + mock_factory.return_value = mock_auth_instance + mock_encrypter.encrypt_token.side_effect = Exception("Encryption error") + with pytest.raises(Exception, match="Encryption error"): + ApiKeyAuthService.create_provider_auth(tenant_id, mock_args) + + def test_validate_api_key_auth_args_none_input(self): + with pytest.raises(TypeError): + ApiKeyAuthService.validate_api_key_auth_args(None) + + def test_validate_api_key_auth_args_dict_credentials_with_list_auth_type(self, mock_args): + mock_args["credentials"]["auth_type"] = ["api_key"] + ApiKeyAuthService.validate_api_key_auth_args(mock_args) diff --git a/api/tests/test_containers_integration_tests/services/auth/test_auth_integration.py b/api/tests/test_containers_integration_tests/services/auth/test_auth_integration.py new file mode 100644 index 0000000000..dc4c0fda1d --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/auth/test_auth_integration.py @@ -0,0 +1,264 @@ +""" +API Key Authentication System Integration Tests +""" + +from __future__ import annotations + +from concurrent.futures import ThreadPoolExecutor +from unittest.mock import Mock, patch +from uuid import uuid4 + +import httpx +import pytest + +from models.source import DataSourceApiKeyAuthBinding +from services.auth.api_key_auth_factory import ApiKeyAuthFactory +from services.auth.api_key_auth_service import ApiKeyAuthService +from 
services.auth.auth_type import AuthType + + +class TestAuthIntegration: + @pytest.fixture + def tenant_id_1(self) -> str: + return str(uuid4()) + + @pytest.fixture + def tenant_id_2(self) -> str: + return str(uuid4()) + + @pytest.fixture + def category(self) -> str: + return "search" + + @pytest.fixture + def firecrawl_credentials(self) -> dict: + return {"auth_type": "bearer", "config": {"api_key": "fc_test_key_123"}} + + @pytest.fixture + def jina_credentials(self) -> dict: + return {"auth_type": "bearer", "config": {"api_key": "jina_test_key_456"}} + + @patch("services.auth.firecrawl.firecrawl.httpx.post") + @patch("services.auth.api_key_auth_service.encrypter.encrypt_token") + def test_end_to_end_auth_flow( + self, + mock_encrypt, + mock_http, + flask_app_with_containers, + db_session_with_containers, + tenant_id_1, + category, + firecrawl_credentials, + ): + mock_http.return_value = self._create_success_response() + mock_encrypt.return_value = "encrypted_fc_test_key_123" + + args = {"category": category, "provider": AuthType.FIRECRAWL, "credentials": firecrawl_credentials} + ApiKeyAuthService.create_provider_auth(tenant_id_1, args) + + mock_http.assert_called_once() + call_args = mock_http.call_args + assert "https://api.firecrawl.dev/v1/crawl" in call_args[0][0] + assert call_args[1]["headers"]["Authorization"] == "Bearer fc_test_key_123" + + mock_encrypt.assert_called_once_with(tenant_id_1, "fc_test_key_123") + + db_session_with_containers.expire_all() + bindings = db_session_with_containers.query(DataSourceApiKeyAuthBinding).filter_by(tenant_id=tenant_id_1).all() + assert len(bindings) == 1 + assert bindings[0].provider == AuthType.FIRECRAWL + + @patch("services.auth.firecrawl.firecrawl.httpx.post") + def test_cross_component_integration(self, mock_http, firecrawl_credentials): + mock_http.return_value = self._create_success_response() + factory = ApiKeyAuthFactory(AuthType.FIRECRAWL, firecrawl_credentials) + result = factory.validate_credentials() + + 
assert result is True + mock_http.assert_called_once() + + @patch("services.auth.api_key_auth_service.encrypter.encrypt_token") + @patch("services.auth.firecrawl.firecrawl.httpx.post") + @patch("services.auth.jina.jina.httpx.post") + def test_multi_tenant_isolation( + self, + mock_jina_http, + mock_fc_http, + mock_encrypt, + flask_app_with_containers, + db_session_with_containers, + tenant_id_1, + tenant_id_2, + category, + firecrawl_credentials, + jina_credentials, + ): + mock_fc_http.return_value = self._create_success_response() + mock_jina_http.return_value = self._create_success_response() + mock_encrypt.return_value = "encrypted_key" + + args1 = {"category": category, "provider": AuthType.FIRECRAWL, "credentials": firecrawl_credentials} + ApiKeyAuthService.create_provider_auth(tenant_id_1, args1) + + args2 = {"category": category, "provider": AuthType.JINA, "credentials": jina_credentials} + ApiKeyAuthService.create_provider_auth(tenant_id_2, args2) + + db_session_with_containers.expire_all() + + result1 = ApiKeyAuthService.get_provider_auth_list(tenant_id_1) + result2 = ApiKeyAuthService.get_provider_auth_list(tenant_id_2) + + assert len(result1) == 1 + assert result1[0].tenant_id == tenant_id_1 + assert len(result2) == 1 + assert result2[0].tenant_id == tenant_id_2 + + def test_cross_tenant_access_prevention( + self, flask_app_with_containers, db_session_with_containers, tenant_id_2, category + ): + result = ApiKeyAuthService.get_auth_credentials(tenant_id_2, category, AuthType.FIRECRAWL) + + assert result is None + + def test_sensitive_data_protection(self): + credentials_with_secrets = { + "auth_type": "bearer", + "config": {"api_key": "super_secret_key_do_not_log", "secret": "another_secret"}, + } + + factory = ApiKeyAuthFactory(AuthType.FIRECRAWL, credentials_with_secrets) + factory_str = str(factory) + + assert "super_secret_key_do_not_log" not in factory_str + assert "another_secret" not in factory_str + + 
@patch("services.auth.firecrawl.firecrawl.httpx.post") + @patch("services.auth.api_key_auth_service.encrypter.encrypt_token", return_value="encrypted_key") + def test_concurrent_creation_safety( + self, + mock_encrypt, + mock_http, + flask_app_with_containers, + db_session_with_containers, + tenant_id_1, + category, + firecrawl_credentials, + ): + app = flask_app_with_containers + mock_http.return_value = self._create_success_response() + + results = [] + exceptions = [] + + def create_auth(): + try: + with app.app_context(): + thread_args = { + "category": category, + "provider": AuthType.FIRECRAWL, + "credentials": {"auth_type": "bearer", "config": {"api_key": "fc_test_key_123"}}, + } + ApiKeyAuthService.create_provider_auth(tenant_id_1, thread_args) + results.append("success") + except Exception as e: + exceptions.append(e) + + with ThreadPoolExecutor(max_workers=5) as executor: + futures = [executor.submit(create_auth) for _ in range(5)] + for future in futures: + future.result() + + assert len(results) == 5 + assert len(exceptions) == 0 + + @pytest.mark.parametrize( + "invalid_input", + [ + None, + {}, + {"auth_type": "bearer"}, + {"auth_type": "bearer", "config": {}}, + ], + ) + def test_invalid_input_boundary(self, invalid_input): + with pytest.raises((ValueError, KeyError, TypeError, AttributeError)): + ApiKeyAuthFactory(AuthType.FIRECRAWL, invalid_input) + + @patch("services.auth.firecrawl.firecrawl.httpx.post") + def test_http_error_handling(self, mock_http, firecrawl_credentials): + mock_response = Mock() + mock_response.status_code = 401 + mock_response.text = '{"error": "Unauthorized"}' + mock_response.raise_for_status.side_effect = httpx.HTTPError("Unauthorized") + mock_http.return_value = mock_response + + factory = ApiKeyAuthFactory(AuthType.FIRECRAWL, firecrawl_credentials) + with pytest.raises((httpx.HTTPError, Exception)): + factory.validate_credentials() + + @patch("services.auth.firecrawl.firecrawl.httpx.post") + def 
test_network_failure_recovery( + self, + mock_http, + flask_app_with_containers, + db_session_with_containers, + tenant_id_1, + category, + firecrawl_credentials, + ): + mock_http.side_effect = httpx.RequestError("Network timeout") + + args = {"category": category, "provider": AuthType.FIRECRAWL, "credentials": firecrawl_credentials} + + with pytest.raises(httpx.RequestError): + ApiKeyAuthService.create_provider_auth(tenant_id_1, args) + + db_session_with_containers.expire_all() + bindings = db_session_with_containers.query(DataSourceApiKeyAuthBinding).filter_by(tenant_id=tenant_id_1).all() + assert len(bindings) == 0 + + @pytest.mark.parametrize( + ("provider", "credentials"), + [ + (AuthType.FIRECRAWL, {"auth_type": "bearer", "config": {"api_key": "fc_key"}}), + (AuthType.JINA, {"auth_type": "bearer", "config": {"api_key": "jina_key"}}), + (AuthType.WATERCRAWL, {"auth_type": "x-api-key", "config": {"api_key": "wc_key"}}), + ], + ) + def test_all_providers_factory_creation(self, provider, credentials): + auth_class = ApiKeyAuthFactory.get_apikey_auth_factory(provider) + assert auth_class is not None + + factory = ApiKeyAuthFactory(provider, credentials) + assert factory.auth is not None + + @patch("services.auth.api_key_auth_service.encrypter.encrypt_token") + @patch("services.auth.firecrawl.firecrawl.httpx.post") + def test_get_auth_credentials_returns_stored_credentials( + self, + mock_http, + mock_encrypt, + flask_app_with_containers, + db_session_with_containers, + tenant_id_1, + category, + firecrawl_credentials, + ): + mock_http.return_value = self._create_success_response() + mock_encrypt.return_value = "encrypted_key" + + args = {"category": category, "provider": AuthType.FIRECRAWL, "credentials": firecrawl_credentials} + ApiKeyAuthService.create_provider_auth(tenant_id_1, args) + + db_session_with_containers.expire_all() + + result = ApiKeyAuthService.get_auth_credentials(tenant_id_1, category, AuthType.FIRECRAWL) + assert result is not None + assert 
result["config"]["api_key"] == "encrypted_key" + + def _create_success_response(self, status_code=200): + mock_response = Mock() + mock_response.status_code = status_code + mock_response.json.return_value = {"status": "success"} + mock_response.raise_for_status.return_value = None + return mock_response diff --git a/api/dify_graph/model_runtime/callbacks/__init__.py b/api/tests/test_containers_integration_tests/services/enterprise/__init__.py similarity index 100% rename from api/dify_graph/model_runtime/callbacks/__init__.py rename to api/tests/test_containers_integration_tests/services/enterprise/__init__.py diff --git a/api/tests/test_containers_integration_tests/services/enterprise/test_account_deletion_sync.py b/api/tests/test_containers_integration_tests/services/enterprise/test_account_deletion_sync.py new file mode 100644 index 0000000000..4e8255d8ed --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/enterprise/test_account_deletion_sync.py @@ -0,0 +1,200 @@ +"""Integration tests for account deletion synchronization. + +Verifies enterprise account deletion sync functionality including +Redis queuing, error handling, and community vs enterprise behavior. 
+""" + +from __future__ import annotations + +from unittest.mock import patch +from uuid import uuid4 + +import pytest +from redis import RedisError + +from extensions.ext_redis import redis_client +from models.account import TenantAccountJoin +from services.enterprise.account_deletion_sync import ( + _queue_task, + sync_account_deletion, + sync_workspace_member_removal, +) + + +class TestQueueTask: + def test_queue_task_success(self): + workspace_id = str(uuid4()) + member_id = str(uuid4()) + + result = _queue_task(workspace_id=workspace_id, member_id=member_id, source="test_source") + + assert result is True + + import json + + raw = redis_client.rpop("enterprise:member:sync:queue") + assert raw is not None + task_data = json.loads(raw) + assert task_data["workspace_id"] == workspace_id + assert task_data["member_id"] == member_id + assert task_data["source"] == "test_source" + assert task_data["type"] == "sync_member_deletion_from_workspace" + assert task_data["retry_count"] == 0 + assert "task_id" in task_data + assert "created_at" in task_data + + def test_queue_task_redis_error(self, caplog): + with patch("services.enterprise.account_deletion_sync.redis_client") as mock_redis: + mock_redis.lpush.side_effect = RedisError("Connection failed") + + result = _queue_task(workspace_id="ws-123", member_id="member-456", source="test_source") + + assert result is False + assert "Failed to queue account deletion sync" in caplog.text + + def test_queue_task_type_error(self, caplog): + with patch("services.enterprise.account_deletion_sync.redis_client") as mock_redis: + mock_redis.lpush.side_effect = TypeError("Cannot serialize") + + result = _queue_task(workspace_id="ws-123", member_id="member-456", source="test_source") + + assert result is False + assert "Failed to queue account deletion sync" in caplog.text + + +class TestSyncWorkspaceMemberRemoval: + @pytest.fixture + def mock_queue_task(self): + with patch("services.enterprise.account_deletion_sync._queue_task") as 
mock_queue: + mock_queue.return_value = True + yield mock_queue + + def test_sync_workspace_member_removal_enterprise_enabled(self, mock_queue_task): + workspace_id = str(uuid4()) + member_id = str(uuid4()) + + with patch("services.enterprise.account_deletion_sync.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = True + + result = sync_workspace_member_removal(workspace_id=workspace_id, member_id=member_id, source="removed") + + assert result is True + mock_queue_task.assert_called_once_with(workspace_id=workspace_id, member_id=member_id, source="removed") + + def test_sync_workspace_member_removal_enterprise_disabled(self, mock_queue_task): + with patch("services.enterprise.account_deletion_sync.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = False + + result = sync_workspace_member_removal( + workspace_id=str(uuid4()), member_id=str(uuid4()), source="test_source" + ) + + assert result is True + mock_queue_task.assert_not_called() + + def test_sync_workspace_member_removal_queue_failure(self, mock_queue_task): + mock_queue_task.return_value = False + + with patch("services.enterprise.account_deletion_sync.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = True + + result = sync_workspace_member_removal( + workspace_id=str(uuid4()), member_id=str(uuid4()), source="test_source" + ) + + assert result is False + + +class TestSyncAccountDeletion: + @pytest.fixture + def mock_queue_task(self): + with patch("services.enterprise.account_deletion_sync._queue_task") as mock_queue: + mock_queue.return_value = True + yield mock_queue + + def test_sync_account_deletion_enterprise_disabled(self, mock_queue_task): + with patch("services.enterprise.account_deletion_sync.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = False + + result = sync_account_deletion(account_id=str(uuid4()), source="account_deleted") + + assert result is True + mock_queue_task.assert_not_called() + + def 
test_sync_account_deletion_multiple_workspaces( + self, flask_app_with_containers, db_session_with_containers, mock_queue_task + ): + account_id = str(uuid4()) + tenant_ids = [str(uuid4()) for _ in range(3)] + + for tenant_id in tenant_ids: + join = TenantAccountJoin(tenant_id=tenant_id, account_id=account_id) + db_session_with_containers.add(join) + db_session_with_containers.commit() + + with patch("services.enterprise.account_deletion_sync.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = True + + result = sync_account_deletion(account_id=account_id, source="account_deleted") + + assert result is True + assert mock_queue_task.call_count == 3 + + queued_workspace_ids = {call.kwargs["workspace_id"] for call in mock_queue_task.call_args_list} + assert queued_workspace_ids == set(tenant_ids) + + def test_sync_account_deletion_no_workspaces( + self, flask_app_with_containers, db_session_with_containers, mock_queue_task + ): + with patch("services.enterprise.account_deletion_sync.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = True + + result = sync_account_deletion(account_id=str(uuid4()), source="account_deleted") + + assert result is True + mock_queue_task.assert_not_called() + + def test_sync_account_deletion_partial_failure( + self, flask_app_with_containers, db_session_with_containers, mock_queue_task + ): + account_id = str(uuid4()) + tenant_ids = [str(uuid4()) for _ in range(3)] + fail_tenant = tenant_ids[1] + + for tenant_id in tenant_ids: + join = TenantAccountJoin(tenant_id=tenant_id, account_id=account_id) + db_session_with_containers.add(join) + db_session_with_containers.commit() + + def queue_side_effect(workspace_id, member_id, source): + return workspace_id != fail_tenant + + mock_queue_task.side_effect = queue_side_effect + + with patch("services.enterprise.account_deletion_sync.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = True + + result = sync_account_deletion(account_id=account_id, 
source="account_deleted") + + assert result is False + assert mock_queue_task.call_count == 3 + + def test_sync_account_deletion_all_failures( + self, flask_app_with_containers, db_session_with_containers, mock_queue_task + ): + account_id = str(uuid4()) + tenant_id = str(uuid4()) + + join = TenantAccountJoin(tenant_id=tenant_id, account_id=account_id) + db_session_with_containers.add(join) + db_session_with_containers.commit() + + mock_queue_task.return_value = False + + with patch("services.enterprise.account_deletion_sync.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = True + + result = sync_account_deletion(account_id=account_id, source="account_deleted") + + assert result is False + mock_queue_task.assert_called_once() diff --git a/api/dify_graph/model_runtime/errors/__init__.py b/api/tests/test_containers_integration_tests/services/plugin/__init__.py similarity index 100% rename from api/dify_graph/model_runtime/errors/__init__.py rename to api/tests/test_containers_integration_tests/services/plugin/__init__.py diff --git a/api/tests/unit_tests/services/plugin/test_plugin_parameter_service.py b/api/tests/test_containers_integration_tests/services/plugin/test_plugin_parameter_service.py similarity index 78% rename from api/tests/unit_tests/services/plugin/test_plugin_parameter_service.py rename to api/tests/test_containers_integration_tests/services/plugin/test_plugin_parameter_service.py index bfa9fe976b..ce9f10e207 100644 --- a/api/tests/unit_tests/services/plugin/test_plugin_parameter_service.py +++ b/api/tests/test_containers_integration_tests/services/plugin/test_plugin_parameter_service.py @@ -6,10 +6,14 @@ HIDDEN_VALUE replacement, and error handling for missing records. 
from __future__ import annotations +import json from unittest.mock import MagicMock, patch +from uuid import uuid4 import pytest +from core.plugin.entities.plugin_daemon import CredentialType +from models.tools import BuiltinToolProvider from services.plugin.plugin_parameter_service import PluginParameterService @@ -39,67 +43,73 @@ class TestGetDynamicSelectOptionsTool: @patch("services.plugin.plugin_parameter_service.DynamicSelectClient") @patch("services.plugin.plugin_parameter_service.create_tool_provider_encrypter") - @patch("services.plugin.plugin_parameter_service.db") @patch("services.plugin.plugin_parameter_service.ToolManager") - def test_fetches_credentials_with_credential_id(self, mock_tool_mgr, mock_db, mock_encrypter_fn, mock_client_cls): + def test_fetches_credentials_with_credential_id( + self, + mock_tool_mgr, + mock_encrypter_fn, + mock_client_cls, + flask_app_with_containers, + db_session_with_containers, + ): + tenant_id = str(uuid4()) provider_ctrl = MagicMock() provider_ctrl.need_credentials = True mock_tool_mgr.get_builtin_provider.return_value = provider_ctrl encrypter = MagicMock() encrypter.decrypt.return_value = {"api_key": "decrypted"} mock_encrypter_fn.return_value = (encrypter, None) + mock_client_cls.return_value.fetch_dynamic_select_options.return_value.options = ["opt1"] - # Mock the Session/query chain - db_record = MagicMock() - db_record.credentials = {"api_key": "encrypted"} - db_record.credential_type = "api_key" + db_record = BuiltinToolProvider( + tenant_id=tenant_id, + user_id=str(uuid4()), + provider="google", + name="API KEY 1", + encrypted_credentials=json.dumps({"api_key": "encrypted"}), + credential_type=CredentialType.API_KEY, + ) + db_session_with_containers.add(db_record) + db_session_with_containers.commit() - with patch("services.plugin.plugin_parameter_service.Session") as mock_session_cls: - mock_session = MagicMock() - mock_session_cls.return_value.__enter__ = MagicMock(return_value=mock_session) - 
mock_session_cls.return_value.__exit__ = MagicMock(return_value=False) - mock_session.query.return_value.where.return_value.first.return_value = db_record - mock_client_cls.return_value.fetch_dynamic_select_options.return_value.options = ["opt1"] - - result = PluginParameterService.get_dynamic_select_options( - tenant_id="t1", - user_id="u1", - plugin_id="p1", - provider="google", - action="search", - parameter="engine", - credential_id="cred-1", - provider_type="tool", - ) + result = PluginParameterService.get_dynamic_select_options( + tenant_id=tenant_id, + user_id="u1", + plugin_id="p1", + provider="google", + action="search", + parameter="engine", + credential_id=db_record.id, + provider_type="tool", + ) assert result == ["opt1"] @patch("services.plugin.plugin_parameter_service.create_tool_provider_encrypter") - @patch("services.plugin.plugin_parameter_service.db") @patch("services.plugin.plugin_parameter_service.ToolManager") - def test_raises_when_tool_provider_not_found(self, mock_tool_mgr, mock_db, mock_encrypter_fn): + def test_raises_when_tool_provider_not_found( + self, + mock_tool_mgr, + mock_encrypter_fn, + flask_app_with_containers, + db_session_with_containers, + ): provider_ctrl = MagicMock() provider_ctrl.need_credentials = True mock_tool_mgr.get_builtin_provider.return_value = provider_ctrl mock_encrypter_fn.return_value = (MagicMock(), None) - with patch("services.plugin.plugin_parameter_service.Session") as mock_session_cls: - mock_session = MagicMock() - mock_session_cls.return_value.__enter__ = MagicMock(return_value=mock_session) - mock_session_cls.return_value.__exit__ = MagicMock(return_value=False) - mock_session.query.return_value.where.return_value.order_by.return_value.first.return_value = None - - with pytest.raises(ValueError, match="not found"): - PluginParameterService.get_dynamic_select_options( - tenant_id="t1", - user_id="u1", - plugin_id="p1", - provider="google", - action="search", - parameter="engine", - credential_id=None, - 
provider_type="tool", - ) + with pytest.raises(ValueError, match="not found"): + PluginParameterService.get_dynamic_select_options( + tenant_id=str(uuid4()), + user_id="u1", + plugin_id="p1", + provider="google", + action="search", + parameter="engine", + credential_id=None, + provider_type="tool", + ) class TestGetDynamicSelectOptionsTrigger: diff --git a/api/tests/unit_tests/services/plugin/test_plugin_service.py b/api/tests/test_containers_integration_tests/services/plugin/test_plugin_service.py similarity index 78% rename from api/tests/unit_tests/services/plugin/test_plugin_service.py rename to api/tests/test_containers_integration_tests/services/plugin/test_plugin_service.py index 09b9ab498b..0cdae572fb 100644 --- a/api/tests/unit_tests/services/plugin/test_plugin_service.py +++ b/api/tests/test_containers_integration_tests/services/plugin/test_plugin_service.py @@ -8,15 +8,27 @@ verification, marketplace upgrade flows, and uninstall with credential cleanup. from __future__ import annotations from unittest.mock import MagicMock, patch +from uuid import uuid4 import pytest +from sqlalchemy import select from core.plugin.entities.plugin import PluginInstallationSource from core.plugin.entities.plugin_daemon import PluginVerification +from models.provider import Provider, ProviderCredential, TenantPreferredModelProvider from services.errors.plugin import PluginInstallationForbiddenError from services.feature_service import PluginInstallationScope from services.plugin.plugin_service import PluginService -from tests.unit_tests.services.plugin.conftest import make_features + + +def _make_features( + restrict_to_marketplace: bool = False, + scope: PluginInstallationScope = PluginInstallationScope.ALL, +) -> MagicMock: + features = MagicMock() + features.plugin_installation_permission.restrict_to_marketplace_only = restrict_to_marketplace + features.plugin_installation_permission.plugin_installation_scope = scope + return features class TestFetchLatestPluginVersion: 
@@ -80,14 +92,14 @@ class TestFetchLatestPluginVersion: class TestCheckMarketplaceOnlyPermission: @patch("services.plugin.plugin_service.FeatureService") def test_raises_when_restricted(self, mock_fs): - mock_fs.get_system_features.return_value = make_features(restrict_to_marketplace=True) + mock_fs.get_system_features.return_value = _make_features(restrict_to_marketplace=True) with pytest.raises(PluginInstallationForbiddenError): PluginService._check_marketplace_only_permission() @patch("services.plugin.plugin_service.FeatureService") def test_passes_when_not_restricted(self, mock_fs): - mock_fs.get_system_features.return_value = make_features(restrict_to_marketplace=False) + mock_fs.get_system_features.return_value = _make_features(restrict_to_marketplace=False) PluginService._check_marketplace_only_permission() # should not raise @@ -95,7 +107,7 @@ class TestCheckMarketplaceOnlyPermission: class TestCheckPluginInstallationScope: @patch("services.plugin.plugin_service.FeatureService") def test_official_only_allows_langgenius(self, mock_fs): - mock_fs.get_system_features.return_value = make_features(scope=PluginInstallationScope.OFFICIAL_ONLY) + mock_fs.get_system_features.return_value = _make_features(scope=PluginInstallationScope.OFFICIAL_ONLY) verification = MagicMock() verification.authorized_category = PluginVerification.AuthorizedCategory.Langgenius @@ -103,14 +115,14 @@ class TestCheckPluginInstallationScope: @patch("services.plugin.plugin_service.FeatureService") def test_official_only_rejects_third_party(self, mock_fs): - mock_fs.get_system_features.return_value = make_features(scope=PluginInstallationScope.OFFICIAL_ONLY) + mock_fs.get_system_features.return_value = _make_features(scope=PluginInstallationScope.OFFICIAL_ONLY) with pytest.raises(PluginInstallationForbiddenError): PluginService._check_plugin_installation_scope(None) @patch("services.plugin.plugin_service.FeatureService") def test_official_and_partners_allows_partner(self, mock_fs): - 
mock_fs.get_system_features.return_value = make_features( + mock_fs.get_system_features.return_value = _make_features( scope=PluginInstallationScope.OFFICIAL_AND_SPECIFIC_PARTNERS ) verification = MagicMock() @@ -120,7 +132,7 @@ class TestCheckPluginInstallationScope: @patch("services.plugin.plugin_service.FeatureService") def test_official_and_partners_rejects_none(self, mock_fs): - mock_fs.get_system_features.return_value = make_features( + mock_fs.get_system_features.return_value = _make_features( scope=PluginInstallationScope.OFFICIAL_AND_SPECIFIC_PARTNERS ) @@ -129,7 +141,7 @@ class TestCheckPluginInstallationScope: @patch("services.plugin.plugin_service.FeatureService") def test_none_scope_always_raises(self, mock_fs): - mock_fs.get_system_features.return_value = make_features(scope=PluginInstallationScope.NONE) + mock_fs.get_system_features.return_value = _make_features(scope=PluginInstallationScope.NONE) verification = MagicMock() verification.authorized_category = PluginVerification.AuthorizedCategory.Langgenius @@ -138,7 +150,7 @@ class TestCheckPluginInstallationScope: @patch("services.plugin.plugin_service.FeatureService") def test_all_scope_passes_any(self, mock_fs): - mock_fs.get_system_features.return_value = make_features(scope=PluginInstallationScope.ALL) + mock_fs.get_system_features.return_value = _make_features(scope=PluginInstallationScope.ALL) PluginService._check_plugin_installation_scope(None) # should not raise @@ -209,9 +221,9 @@ class TestUpgradePluginWithMarketplace: @patch("services.plugin.plugin_service.dify_config") def test_skips_download_when_already_installed(self, mock_config, mock_installer_cls, mock_fs, mock_marketplace): mock_config.MARKETPLACE_ENABLED = True - mock_fs.get_system_features.return_value = make_features() + mock_fs.get_system_features.return_value = _make_features() installer = mock_installer_cls.return_value - installer.fetch_plugin_manifest.return_value = MagicMock() # no exception = already installed + 
installer.fetch_plugin_manifest.return_value = MagicMock() installer.upgrade_plugin.return_value = MagicMock() PluginService.upgrade_plugin_with_marketplace("t1", "old-uid", "new-uid") @@ -225,7 +237,7 @@ class TestUpgradePluginWithMarketplace: @patch("services.plugin.plugin_service.dify_config") def test_downloads_when_not_installed(self, mock_config, mock_installer_cls, mock_fs, mock_download): mock_config.MARKETPLACE_ENABLED = True - mock_fs.get_system_features.return_value = make_features() + mock_fs.get_system_features.return_value = _make_features() installer = mock_installer_cls.return_value installer.fetch_plugin_manifest.side_effect = RuntimeError("not found") mock_download.return_value = b"pkg-bytes" @@ -244,7 +256,7 @@ class TestUpgradePluginWithGithub: @patch("services.plugin.plugin_service.FeatureService") @patch("services.plugin.plugin_service.PluginInstaller") def test_checks_marketplace_permission_and_delegates(self, mock_installer_cls, mock_fs): - mock_fs.get_system_features.return_value = make_features() + mock_fs.get_system_features.return_value = _make_features() installer = mock_installer_cls.return_value installer.upgrade_plugin.return_value = MagicMock() @@ -259,7 +271,7 @@ class TestUploadPkg: @patch("services.plugin.plugin_service.FeatureService") @patch("services.plugin.plugin_service.PluginInstaller") def test_runs_permission_and_scope_checks(self, mock_installer_cls, mock_fs): - mock_fs.get_system_features.return_value = make_features() + mock_fs.get_system_features.return_value = _make_features() upload_resp = MagicMock() upload_resp.verification = None mock_installer_cls.return_value.upload_pkg.return_value = upload_resp @@ -283,7 +295,7 @@ class TestInstallFromMarketplacePkg: @patch("services.plugin.plugin_service.dify_config") def test_downloads_when_not_cached(self, mock_config, mock_installer_cls, mock_fs, mock_download): mock_config.MARKETPLACE_ENABLED = True - mock_fs.get_system_features.return_value = make_features() + 
mock_fs.get_system_features.return_value = _make_features() installer = mock_installer_cls.return_value installer.fetch_plugin_manifest.side_effect = RuntimeError("not found") mock_download.return_value = b"pkg" @@ -298,14 +310,14 @@ class TestInstallFromMarketplacePkg: assert result == "task-id" installer.install_from_identifiers.assert_called_once() call_args = installer.install_from_identifiers.call_args[0] - assert call_args[1] == ["resolved-uid"] # uses response uid, not input + assert call_args[1] == ["resolved-uid"] @patch("services.plugin.plugin_service.FeatureService") @patch("services.plugin.plugin_service.PluginInstaller") @patch("services.plugin.plugin_service.dify_config") def test_uses_cached_when_already_downloaded(self, mock_config, mock_installer_cls, mock_fs): mock_config.MARKETPLACE_ENABLED = True - mock_fs.get_system_features.return_value = make_features() + mock_fs.get_system_features.return_value = _make_features() installer = mock_installer_cls.return_value installer.fetch_plugin_manifest.return_value = MagicMock() decode_resp = MagicMock() @@ -317,7 +329,7 @@ class TestInstallFromMarketplacePkg: installer.install_from_identifiers.assert_called_once() call_args = installer.install_from_identifiers.call_args[0] - assert call_args[1] == ["uid-1"] # uses original uid + assert call_args[1] == ["uid-1"] class TestUninstall: @@ -332,26 +344,70 @@ class TestUninstall: assert result is True installer.uninstall.assert_called_once_with("t1", "install-1") - @patch("services.plugin.plugin_service.db") @patch("services.plugin.plugin_service.PluginInstaller") - def test_cleans_credentials_when_plugin_found(self, mock_installer_cls, mock_db): + def test_cleans_credentials_when_plugin_found( + self, mock_installer_cls, flask_app_with_containers, db_session_with_containers + ): + tenant_id = str(uuid4()) + plugin_id = "org/myplugin" + provider_name = f"{plugin_id}/model-provider" + + credential = ProviderCredential( + tenant_id=tenant_id, + 
provider_name=provider_name, + credential_name="default", + encrypted_config="{}", + ) + db_session_with_containers.add(credential) + db_session_with_containers.flush() + credential_id = credential.id + + provider = Provider( + tenant_id=tenant_id, + provider_name=provider_name, + credential_id=credential_id, + ) + db_session_with_containers.add(provider) + db_session_with_containers.flush() + provider_id = provider.id + + pref = TenantPreferredModelProvider( + tenant_id=tenant_id, + provider_name=provider_name, + preferred_provider_type="custom", + ) + db_session_with_containers.add(pref) + db_session_with_containers.commit() + plugin = MagicMock() plugin.installation_id = "install-1" - plugin.plugin_id = "org/myplugin" + plugin.plugin_id = plugin_id installer = mock_installer_cls.return_value installer.list_plugins.return_value = [plugin] installer.uninstall.return_value = True - # Mock Session context manager - mock_session = MagicMock() - mock_db.engine = MagicMock() - mock_session.scalars.return_value.all.return_value = [] # no credentials found - - with patch("services.plugin.plugin_service.Session") as mock_session_cls: - mock_session_cls.return_value.__enter__ = MagicMock(return_value=mock_session) - mock_session_cls.return_value.__exit__ = MagicMock(return_value=False) - - result = PluginService.uninstall("t1", "install-1") + with patch("services.plugin.plugin_service.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = False + result = PluginService.uninstall(tenant_id, "install-1") assert result is True installer.uninstall.assert_called_once() + + db_session_with_containers.expire_all() + + remaining_creds = db_session_with_containers.scalars( + select(ProviderCredential).where(ProviderCredential.id == credential_id) + ).all() + assert len(remaining_creds) == 0 + + updated_provider = db_session_with_containers.get(Provider, provider_id) + assert updated_provider is not None + assert updated_provider.credential_id is None + + remaining_prefs = 
db_session_with_containers.scalars( + select(TenantPreferredModelProvider).where( + TenantPreferredModelProvider.tenant_id == tenant_id, + TenantPreferredModelProvider.provider_name == provider_name, + ) + ).all() + assert len(remaining_prefs) == 0 diff --git a/api/dify_graph/model_runtime/model_providers/__base/__init__.py b/api/tests/test_containers_integration_tests/services/recommend_app/__init__.py similarity index 100% rename from api/dify_graph/model_runtime/model_providers/__base/__init__.py rename to api/tests/test_containers_integration_tests/services/recommend_app/__init__.py diff --git a/api/tests/test_containers_integration_tests/services/recommend_app/test_database_retrieval.py b/api/tests/test_containers_integration_tests/services/recommend_app/test_database_retrieval.py new file mode 100644 index 0000000000..2b842629a7 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/recommend_app/test_database_retrieval.py @@ -0,0 +1,184 @@ +from __future__ import annotations + +from unittest.mock import patch +from uuid import uuid4 + +from models.model import App, RecommendedApp, Site +from services.recommend_app.database.database_retrieval import DatabaseRecommendAppRetrieval +from services.recommend_app.recommend_app_type import RecommendAppType + + +def _create_app(db_session, *, tenant_id: str, is_public: bool = True) -> App: + app = App( + tenant_id=tenant_id, + name=f"app-{uuid4()}", + mode="chat", + enable_site=True, + enable_api=True, + is_public=is_public, + ) + app.id = str(uuid4()) + db_session.add(app) + db_session.commit() + return app + + +def _create_site(db_session, *, app_id: str) -> Site: + site = Site( + app_id=app_id, + title=f"site-{uuid4()}", + default_language="en-US", + customize_token_strategy="not_allow", + description="desc", + copyright="copy", + privacy_policy="pp", + custom_disclaimer="cd", + ) + site.id = str(uuid4()) + db_session.add(site) + db_session.commit() + return site + + +def _create_recommended_app( 
+ db_session, + *, + app_id: str, + category: str = "chat", + language: str = "en-US", + is_listed: bool = True, + position: int = 1, +) -> RecommendedApp: + rec = RecommendedApp( + app_id=app_id, + description={"en-US": "test"}, + copyright="copy", + privacy_policy="pp", + category=category, + language=language, + is_listed=is_listed, + position=position, + ) + rec.id = str(uuid4()) + db_session.add(rec) + db_session.commit() + return rec + + +class TestDatabaseRecommendAppRetrieval: + def test_get_type(self): + assert DatabaseRecommendAppRetrieval().get_type() == RecommendAppType.DATABASE + + def test_get_recommended_apps_delegates(self): + with patch.object( + DatabaseRecommendAppRetrieval, + "fetch_recommended_apps_from_db", + return_value={"recommended_apps": [], "categories": []}, + ) as mock_fetch: + result = DatabaseRecommendAppRetrieval().get_recommended_apps_and_categories("en-US") + mock_fetch.assert_called_once_with("en-US") + assert result == {"recommended_apps": [], "categories": []} + + def test_get_recommend_app_detail_delegates(self): + with patch.object( + DatabaseRecommendAppRetrieval, + "fetch_recommended_app_detail_from_db", + return_value={"id": "app-1"}, + ) as mock_fetch: + result = DatabaseRecommendAppRetrieval().get_recommend_app_detail("app-1") + mock_fetch.assert_called_once_with("app-1") + assert result == {"id": "app-1"} + + +class TestFetchRecommendedAppsFromDb: + def test_returns_apps_and_sorted_categories(self, flask_app_with_containers, db_session_with_containers): + tenant_id = str(uuid4()) + app1 = _create_app(db_session_with_containers, tenant_id=tenant_id) + _create_site(db_session_with_containers, app_id=app1.id) + _create_recommended_app(db_session_with_containers, app_id=app1.id, category="writing") + + app2 = _create_app(db_session_with_containers, tenant_id=tenant_id) + _create_site(db_session_with_containers, app_id=app2.id) + _create_recommended_app(db_session_with_containers, app_id=app2.id, category="assistant") + + 
db_session_with_containers.expire_all() + + result = DatabaseRecommendAppRetrieval.fetch_recommended_apps_from_db("en-US") + + app_ids = {r["app_id"] for r in result["recommended_apps"]} + assert app1.id in app_ids + assert app2.id in app_ids + assert "assistant" in result["categories"] + assert "writing" in result["categories"] + + def test_falls_back_to_default_language_when_empty(self, flask_app_with_containers, db_session_with_containers): + tenant_id = str(uuid4()) + app1 = _create_app(db_session_with_containers, tenant_id=tenant_id) + _create_site(db_session_with_containers, app_id=app1.id) + _create_recommended_app(db_session_with_containers, app_id=app1.id, language="en-US") + + db_session_with_containers.expire_all() + + result = DatabaseRecommendAppRetrieval.fetch_recommended_apps_from_db("fr-FR") + + app_ids = {r["app_id"] for r in result["recommended_apps"]} + assert app1.id in app_ids + + def test_skips_non_public_apps(self, flask_app_with_containers, db_session_with_containers): + tenant_id = str(uuid4()) + app1 = _create_app(db_session_with_containers, tenant_id=tenant_id, is_public=False) + _create_site(db_session_with_containers, app_id=app1.id) + _create_recommended_app(db_session_with_containers, app_id=app1.id) + + db_session_with_containers.expire_all() + + result = DatabaseRecommendAppRetrieval.fetch_recommended_apps_from_db("en-US") + + app_ids = {r["app_id"] for r in result["recommended_apps"]} + assert app1.id not in app_ids + + def test_skips_apps_without_site(self, flask_app_with_containers, db_session_with_containers): + tenant_id = str(uuid4()) + app1 = _create_app(db_session_with_containers, tenant_id=tenant_id) + _create_recommended_app(db_session_with_containers, app_id=app1.id) + + db_session_with_containers.expire_all() + + result = DatabaseRecommendAppRetrieval.fetch_recommended_apps_from_db("en-US") + + app_ids = {r["app_id"] for r in result["recommended_apps"]} + assert app1.id not in app_ids + + +class 
TestFetchRecommendedAppDetailFromDb: + def test_returns_none_when_not_listed(self, flask_app_with_containers, db_session_with_containers): + result = DatabaseRecommendAppRetrieval.fetch_recommended_app_detail_from_db(str(uuid4())) + + assert result is None + + def test_returns_none_when_app_not_public(self, flask_app_with_containers, db_session_with_containers): + tenant_id = str(uuid4()) + app1 = _create_app(db_session_with_containers, tenant_id=tenant_id, is_public=False) + _create_recommended_app(db_session_with_containers, app_id=app1.id) + + db_session_with_containers.expire_all() + + result = DatabaseRecommendAppRetrieval.fetch_recommended_app_detail_from_db(app1.id) + + assert result is None + + @patch("services.recommend_app.database.database_retrieval.AppDslService") + def test_returns_detail_on_success(self, mock_dsl, flask_app_with_containers, db_session_with_containers): + tenant_id = str(uuid4()) + app1 = _create_app(db_session_with_containers, tenant_id=tenant_id) + _create_site(db_session_with_containers, app_id=app1.id) + _create_recommended_app(db_session_with_containers, app_id=app1.id) + mock_dsl.export_dsl.return_value = "exported_yaml" + + db_session_with_containers.expire_all() + + result = DatabaseRecommendAppRetrieval.fetch_recommended_app_detail_from_db(app1.id) + + assert result is not None + assert result["id"] == app1.id + assert result["export_data"] == "exported_yaml" diff --git a/api/tests/test_containers_integration_tests/services/test_agent_service.py b/api/tests/test_containers_integration_tests/services/test_agent_service.py index b51fbc3a42..4f3c0e4200 100644 --- a/api/tests/test_containers_integration_tests/services/test_agent_service.py +++ b/api/tests/test_containers_integration_tests/services/test_agent_service.py @@ -28,7 +28,7 @@ class TestAgentService: patch("services.agent_service.current_user", create_autospec(Account, instance=True)) as mock_current_user, patch("services.app_service.FeatureService", autospec=True) as 
mock_feature_service, patch("services.app_service.EnterpriseService", autospec=True) as mock_enterprise_service, - patch("services.app_service.ModelManager", autospec=True) as mock_model_manager, + patch("services.app_service.ModelManager.for_tenant", autospec=True) as mock_model_manager, patch("services.account_service.FeatureService", autospec=True) as mock_account_feature_service, ): # Setup default mock returns for agent service @@ -841,7 +841,8 @@ class TestAgentService: app, account = self._create_test_app_and_account(db_session_with_containers, mock_external_service_dependencies) conversation, message = self._create_test_conversation_and_message(db_session_with_containers, app, account) - from dify_graph.file import FileTransferMethod, FileType + from graphon.file import FileTransferMethod, FileType + from models.enums import CreatorUserRole # Add files to message diff --git a/api/tests/unit_tests/services/test_api_token_service.py b/api/tests/test_containers_integration_tests/services/test_api_token_service.py similarity index 71% rename from api/tests/unit_tests/services/test_api_token_service.py rename to api/tests/test_containers_integration_tests/services/test_api_token_service.py index ad4de93b25..a2028d3ed3 100644 --- a/api/tests/unit_tests/services/test_api_token_service.py +++ b/api/tests/test_containers_integration_tests/services/test_api_token_service.py @@ -1,80 +1,63 @@ +from __future__ import annotations + from datetime import datetime -from types import SimpleNamespace from unittest.mock import MagicMock, patch +from uuid import uuid4 import pytest from werkzeug.exceptions import Unauthorized import services.api_token_service as api_token_service_module +from models.model import ApiToken from services.api_token_service import ApiTokenCache, CachedApiToken -@pytest.fixture -def mock_db_session(): - """Fixture providing common DB session mocking for query_token_from_db tests.""" - fake_engine = MagicMock() - - session = MagicMock() - 
session_context = MagicMock() - session_context.__enter__.return_value = session - session_context.__exit__.return_value = None - - with ( - patch.object(api_token_service_module, "db", new=SimpleNamespace(engine=fake_engine)), - patch.object(api_token_service_module, "Session", return_value=session_context) as mock_session_class, - patch.object(api_token_service_module.ApiTokenCache, "set") as mock_cache_set, - patch.object(api_token_service_module, "record_token_usage") as mock_record_usage, - ): - yield { - "session": session, - "mock_session_class": mock_session_class, - "mock_cache_set": mock_cache_set, - "mock_record_usage": mock_record_usage, - "fake_engine": fake_engine, - } - - class TestQueryTokenFromDb: - def test_should_return_api_token_and_cache_when_token_exists(self, mock_db_session): - """Test DB lookup success path caches token and records usage.""" - # Arrange - auth_token = "token-123" - scope = "app" - api_token = MagicMock() + def test_should_return_api_token_and_cache_when_token_exists( + self, flask_app_with_containers, db_session_with_containers + ): + tenant_id = str(uuid4()) + app_id = str(uuid4()) + token_value = f"app-test-{uuid4()}" - mock_db_session["session"].scalar.return_value = api_token + api_token = ApiToken() + api_token.id = str(uuid4()) + api_token.app_id = app_id + api_token.tenant_id = tenant_id + api_token.type = "app" + api_token.token = token_value + db_session_with_containers.add(api_token) + db_session_with_containers.commit() - # Act - result = api_token_service_module.query_token_from_db(auth_token, scope) + with ( + patch.object(api_token_service_module.ApiTokenCache, "set") as mock_cache_set, + patch.object(api_token_service_module, "record_token_usage") as mock_record_usage, + ): + result = api_token_service_module.query_token_from_db(token_value, "app") - # Assert - assert result == api_token - mock_db_session["mock_session_class"].assert_called_once_with( - mock_db_session["fake_engine"], expire_on_commit=False - 
) - mock_db_session["mock_cache_set"].assert_called_once_with(auth_token, scope, api_token) - mock_db_session["mock_record_usage"].assert_called_once_with(auth_token, scope) + assert result.id == api_token.id + assert result.token == token_value + mock_cache_set.assert_called_once() + mock_record_usage.assert_called_once_with(token_value, "app") - def test_should_cache_null_and_raise_unauthorized_when_token_not_found(self, mock_db_session): - """Test DB lookup miss path caches null marker and raises Unauthorized.""" - # Arrange - auth_token = "missing-token" - scope = "app" + def test_should_cache_null_and_raise_unauthorized_when_token_not_found( + self, flask_app_with_containers, db_session_with_containers + ): + with ( + patch.object(api_token_service_module.ApiTokenCache, "set") as mock_cache_set, + patch.object(api_token_service_module, "record_token_usage") as mock_record_usage, + ): + with pytest.raises(Unauthorized, match="Access token is invalid"): + api_token_service_module.query_token_from_db(f"missing-{uuid4()}", "app") - mock_db_session["session"].scalar.return_value = None - - # Act / Assert - with pytest.raises(Unauthorized, match="Access token is invalid"): - api_token_service_module.query_token_from_db(auth_token, scope) - - mock_db_session["mock_cache_set"].assert_called_once_with(auth_token, scope, None) - mock_db_session["mock_record_usage"].assert_not_called() + mock_cache_set.assert_called_once() + call_args = mock_cache_set.call_args[0] + assert call_args[2] is None # cached None + mock_record_usage.assert_not_called() class TestRecordTokenUsage: def test_should_write_active_key_with_iso_timestamp_and_ttl(self): - """Test record_token_usage writes usage timestamp with one-hour TTL.""" - # Arrange auth_token = "token-123" scope = "dataset" fixed_time = datetime(2026, 2, 24, 12, 0, 0) @@ -84,26 +67,18 @@ class TestRecordTokenUsage: patch.object(api_token_service_module, "naive_utc_now", return_value=fixed_time), 
patch.object(api_token_service_module, "redis_client") as mock_redis, ): - # Act api_token_service_module.record_token_usage(auth_token, scope) - # Assert mock_redis.set.assert_called_once_with(expected_key, fixed_time.isoformat(), ex=3600) def test_should_not_raise_when_redis_write_fails(self): - """Test record_token_usage swallows Redis errors.""" - # Arrange with patch.object(api_token_service_module, "redis_client") as mock_redis: mock_redis.set.side_effect = Exception("redis unavailable") - - # Act / Assert api_token_service_module.record_token_usage("token-123", "app") class TestFetchTokenWithSingleFlight: def test_should_return_cached_token_when_lock_acquired_and_cache_filled(self): - """Test single-flight returns cache when another request already populated it.""" - # Arrange auth_token = "token-123" scope = "app" cached_token = CachedApiToken( @@ -115,39 +90,26 @@ class TestFetchTokenWithSingleFlight: last_used_at=None, created_at=None, ) - lock = MagicMock() lock.acquire.return_value = True with ( patch.object(api_token_service_module, "redis_client") as mock_redis, - patch.object(api_token_service_module.ApiTokenCache, "get", return_value=cached_token) as mock_cache_get, + patch.object(api_token_service_module.ApiTokenCache, "get", return_value=cached_token), patch.object(api_token_service_module, "query_token_from_db") as mock_query_db, ): mock_redis.lock.return_value = lock - - # Act result = api_token_service_module.fetch_token_with_single_flight(auth_token, scope) - # Assert assert result == cached_token - mock_redis.lock.assert_called_once_with( - f"api_token_query_lock:{scope}:{auth_token}", - timeout=10, - blocking_timeout=5, - ) lock.acquire.assert_called_once_with(blocking=True) lock.release.assert_called_once() - mock_cache_get.assert_called_once_with(auth_token, scope) mock_query_db.assert_not_called() def test_should_query_db_when_lock_acquired_and_cache_missed(self): - """Test single-flight queries DB when cache remains empty after lock 
acquisition.""" - # Arrange auth_token = "token-123" scope = "app" db_token = MagicMock() - lock = MagicMock() lock.acquire.return_value = True @@ -157,22 +119,16 @@ class TestFetchTokenWithSingleFlight: patch.object(api_token_service_module, "query_token_from_db", return_value=db_token) as mock_query_db, ): mock_redis.lock.return_value = lock - - # Act result = api_token_service_module.fetch_token_with_single_flight(auth_token, scope) - # Assert assert result == db_token mock_query_db.assert_called_once_with(auth_token, scope) lock.release.assert_called_once() def test_should_query_db_directly_when_lock_not_acquired(self): - """Test lock timeout branch falls back to direct DB query.""" - # Arrange auth_token = "token-123" scope = "app" db_token = MagicMock() - lock = MagicMock() lock.acquire.return_value = False @@ -182,19 +138,14 @@ class TestFetchTokenWithSingleFlight: patch.object(api_token_service_module, "query_token_from_db", return_value=db_token) as mock_query_db, ): mock_redis.lock.return_value = lock - - # Act result = api_token_service_module.fetch_token_with_single_flight(auth_token, scope) - # Assert assert result == db_token mock_cache_get.assert_not_called() mock_query_db.assert_called_once_with(auth_token, scope) lock.release.assert_not_called() def test_should_reraise_unauthorized_from_db_query(self): - """Test Unauthorized from DB query is propagated unchanged.""" - # Arrange auth_token = "token-123" scope = "app" lock = MagicMock() @@ -210,20 +161,15 @@ class TestFetchTokenWithSingleFlight: ), ): mock_redis.lock.return_value = lock - - # Act / Assert with pytest.raises(Unauthorized, match="Access token is invalid"): api_token_service_module.fetch_token_with_single_flight(auth_token, scope) lock.release.assert_called_once() def test_should_fallback_to_db_query_when_lock_raises_exception(self): - """Test Redis lock errors fall back to direct DB query.""" - # Arrange auth_token = "token-123" scope = "app" db_token = MagicMock() - lock = MagicMock() 
lock.acquire.side_effect = RuntimeError("redis lock error") @@ -232,11 +178,8 @@ class TestFetchTokenWithSingleFlight: patch.object(api_token_service_module, "query_token_from_db", return_value=db_token) as mock_query_db, ): mock_redis.lock.return_value = lock - - # Act result = api_token_service_module.fetch_token_with_single_flight(auth_token, scope) - # Assert assert result == db_token mock_query_db.assert_called_once_with(auth_token, scope) @@ -244,8 +187,6 @@ class TestFetchTokenWithSingleFlight: class TestApiTokenCacheTenantBranches: @patch("services.api_token_service.redis_client") def test_delete_with_scope_should_remove_from_tenant_index_when_tenant_found(self, mock_redis): - """Test scoped delete removes cache key and tenant index membership.""" - # Arrange token = "token-123" scope = "app" cache_key = ApiTokenCache._make_cache_key(token, scope) @@ -261,18 +202,14 @@ class TestApiTokenCacheTenantBranches: mock_redis.get.return_value = cached_token.model_dump_json().encode("utf-8") with patch.object(ApiTokenCache, "_remove_from_tenant_index") as mock_remove_index: - # Act result = ApiTokenCache.delete(token, scope) - # Assert assert result is True mock_redis.delete.assert_called_once_with(cache_key) mock_remove_index.assert_called_once_with("tenant-1", cache_key) @patch("services.api_token_service.redis_client") def test_invalidate_by_tenant_should_delete_all_indexed_cache_keys(self, mock_redis): - """Test tenant invalidation deletes indexed cache entries and index key.""" - # Arrange tenant_id = "tenant-1" index_key = ApiTokenCache._make_tenant_index_key(tenant_id) mock_redis.smembers.return_value = { @@ -280,10 +217,8 @@ class TestApiTokenCacheTenantBranches: b"api_token:any:token-2", } - # Act result = ApiTokenCache.invalidate_by_tenant(tenant_id) - # Assert assert result is True mock_redis.smembers.assert_called_once_with(index_key) mock_redis.delete.assert_any_call("api_token:app:token-1") @@ -293,7 +228,6 @@ class TestApiTokenCacheTenantBranches: 
class TestApiTokenCacheCoreBranches: def test_cached_api_token_repr_should_include_id_and_type(self): - """Test CachedApiToken __repr__ includes key identity fields.""" token = CachedApiToken( id="id-123", app_id="app-123", @@ -303,11 +237,9 @@ class TestApiTokenCacheCoreBranches: last_used_at=None, created_at=None, ) - assert repr(token) == "" def test_serialize_token_should_handle_cached_api_token_instances(self): - """Test serialization path when input is already a CachedApiToken.""" token = CachedApiToken( id="id-123", app_id="app-123", @@ -317,35 +249,25 @@ class TestApiTokenCacheCoreBranches: last_used_at=None, created_at=None, ) - serialized = ApiTokenCache._serialize_token(token) - assert isinstance(serialized, bytes) assert b'"id":"id-123"' in serialized - assert b'"token":"token-123"' in serialized def test_deserialize_token_should_return_none_for_null_markers(self): - """Test null cache marker deserializes to None.""" assert ApiTokenCache._deserialize_token("null") is None assert ApiTokenCache._deserialize_token(b"null") is None def test_deserialize_token_should_return_none_for_invalid_payload(self): - """Test invalid serialized payload returns None.""" assert ApiTokenCache._deserialize_token("not-json") is None @patch("services.api_token_service.redis_client") def test_get_should_return_none_on_cache_miss(self, mock_redis): - """Test cache miss branch in ApiTokenCache.get.""" mock_redis.get.return_value = None - result = ApiTokenCache.get("token-123", "app") - assert result is None - mock_redis.get.assert_called_once_with("api_token:app:token-123") @patch("services.api_token_service.redis_client") def test_get_should_deserialize_cached_payload_on_cache_hit(self, mock_redis): - """Test cache hit branch in ApiTokenCache.get.""" token = CachedApiToken( id="id-123", app_id="app-123", @@ -356,48 +278,34 @@ class TestApiTokenCacheCoreBranches: created_at=None, ) mock_redis.get.return_value = token.model_dump_json().encode("utf-8") - result = 
ApiTokenCache.get("token-123", "app") - assert isinstance(result, CachedApiToken) assert result.id == "id-123" @patch("services.api_token_service.redis_client") def test_add_to_tenant_index_should_skip_when_tenant_id_missing(self, mock_redis): - """Test tenant index update exits early for missing tenant id.""" ApiTokenCache._add_to_tenant_index(None, "api_token:app:token-123") - mock_redis.sadd.assert_not_called() - mock_redis.expire.assert_not_called() @patch("services.api_token_service.redis_client") def test_add_to_tenant_index_should_swallow_index_update_errors(self, mock_redis): - """Test tenant index update handles Redis write errors gracefully.""" mock_redis.sadd.side_effect = Exception("redis down") - ApiTokenCache._add_to_tenant_index("tenant-123", "api_token:app:token-123") - mock_redis.sadd.assert_called_once() @patch("services.api_token_service.redis_client") def test_remove_from_tenant_index_should_skip_when_tenant_id_missing(self, mock_redis): - """Test tenant index removal exits early for missing tenant id.""" ApiTokenCache._remove_from_tenant_index(None, "api_token:app:token-123") - mock_redis.srem.assert_not_called() @patch("services.api_token_service.redis_client") def test_remove_from_tenant_index_should_swallow_redis_errors(self, mock_redis): - """Test tenant index removal handles Redis errors gracefully.""" mock_redis.srem.side_effect = Exception("redis down") - ApiTokenCache._remove_from_tenant_index("tenant-123", "api_token:app:token-123") - mock_redis.srem.assert_called_once() @patch("services.api_token_service.redis_client") def test_set_should_return_false_when_cache_write_raises_exception(self, mock_redis): - """Test set returns False when Redis setex fails.""" mock_redis.setex.side_effect = Exception("redis write failed") api_token = MagicMock() api_token.id = "id-123" @@ -407,60 +315,41 @@ class TestApiTokenCacheCoreBranches: api_token.token = "token-123" api_token.last_used_at = None api_token.created_at = None - result = 
ApiTokenCache.set("token-123", "app", api_token) - assert result is False @patch("services.api_token_service.redis_client") def test_delete_without_scope_should_return_false_when_scan_fails(self, mock_redis): - """Test delete(scope=None) returns False when scan_iter raises.""" mock_redis.scan_iter.side_effect = Exception("scan failed") - result = ApiTokenCache.delete("token-123", None) - assert result is False @patch("services.api_token_service.redis_client") def test_delete_with_scope_should_continue_when_tenant_lookup_raises(self, mock_redis): - """Test scoped delete still succeeds when tenant lookup from cache fails.""" token = "token-123" scope = "app" cache_key = ApiTokenCache._make_cache_key(token, scope) mock_redis.get.side_effect = Exception("get failed") - result = ApiTokenCache.delete(token, scope) - assert result is True mock_redis.delete.assert_called_once_with(cache_key) @patch("services.api_token_service.redis_client") def test_delete_with_scope_should_return_false_when_delete_raises(self, mock_redis): - """Test scoped delete returns False when delete operation fails.""" - token = "token-123" - scope = "app" mock_redis.get.return_value = None mock_redis.delete.side_effect = Exception("delete failed") - - result = ApiTokenCache.delete(token, scope) - + result = ApiTokenCache.delete("token-123", "app") assert result is False @patch("services.api_token_service.redis_client") def test_invalidate_by_tenant_should_return_true_when_index_not_found(self, mock_redis): - """Test tenant invalidation returns True when tenant index is empty.""" mock_redis.smembers.return_value = set() - result = ApiTokenCache.invalidate_by_tenant("tenant-123") - assert result is True mock_redis.delete.assert_not_called() @patch("services.api_token_service.redis_client") def test_invalidate_by_tenant_should_return_false_when_redis_raises(self, mock_redis): - """Test tenant invalidation returns False when Redis operation fails.""" mock_redis.smembers.side_effect = Exception("redis 
failed") - result = ApiTokenCache.invalidate_by_tenant("tenant-123") - assert result is False diff --git a/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py b/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py index 8a362e1f5e..33955d5d84 100644 --- a/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py +++ b/api/tests/test_containers_integration_tests/services/test_app_dsl_service.py @@ -26,7 +26,7 @@ class TestAppDslService: patch("services.app_dsl_service.redis_client") as mock_redis_client, patch("services.app_dsl_service.app_was_created") as mock_app_was_created, patch("services.app_dsl_service.app_model_config_was_updated") as mock_app_model_config_was_updated, - patch("services.app_service.ModelManager") as mock_model_manager, + patch("services.app_service.ModelManager.for_tenant") as mock_model_manager, patch("services.app_service.FeatureService") as mock_feature_service, patch("services.app_service.EnterpriseService") as mock_enterprise_service, ): diff --git a/api/tests/test_containers_integration_tests/services/test_app_service.py b/api/tests/test_containers_integration_tests/services/test_app_service.py index a83af30fb9..fa57dd4a6f 100644 --- a/api/tests/test_containers_integration_tests/services/test_app_service.py +++ b/api/tests/test_containers_integration_tests/services/test_app_service.py @@ -23,7 +23,7 @@ class TestAppService: with ( patch("services.app_service.FeatureService") as mock_feature_service, patch("services.app_service.EnterpriseService") as mock_enterprise_service, - patch("services.app_service.ModelManager") as mock_model_manager, + patch("services.app_service.ModelManager.for_tenant") as mock_model_manager, patch("services.account_service.FeatureService") as mock_account_feature_service, ): # Setup default mock returns for app service diff --git a/api/tests/test_containers_integration_tests/services/test_conversation_variable_updater.py 
b/api/tests/test_containers_integration_tests/services/test_conversation_variable_updater.py index 42a2215896..fb0adbbcc2 100644 --- a/api/tests/test_containers_integration_tests/services/test_conversation_variable_updater.py +++ b/api/tests/test_containers_integration_tests/services/test_conversation_variable_updater.py @@ -3,9 +3,9 @@ from uuid import uuid4 import pytest +from graphon.variables import StringVariable from sqlalchemy.orm import sessionmaker -from dify_graph.variables import StringVariable from extensions.ext_database import db from models.workflow import ConversationVariable from services.conversation_variable_updater import ConversationVariableNotFoundError, ConversationVariableUpdater diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service.py b/api/tests/test_containers_integration_tests/services/test_dataset_service.py index 0702680f5c..f9bfa570cb 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service.py @@ -9,11 +9,11 @@ from unittest.mock import Mock, patch from uuid import uuid4 import pytest +from graphon.model_runtime.entities.model_entities import ModelType from sqlalchemy.orm import Session from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.retrieval.retrieval_methods import RetrievalMethod -from dify_graph.model_runtime.entities.model_entities import ModelType from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, DatasetPermissionEnum, Document, ExternalKnowledgeBindings, Pipeline from models.enums import DatasetRuntimeMode, DataSourceType, DocumentCreatedFrom, IndexingStatus @@ -174,7 +174,7 @@ class TestDatasetServiceCreateDataset: embedding_model = DatasetServiceIntegrationDataFactory.create_embedding_model() # Act - with patch("services.dataset_service.ModelManager") as 
mock_model_manager: + with patch("services.dataset_service.ModelManager.for_tenant") as mock_model_manager: mock_model_manager.return_value.get_default_model_instance.return_value = embedding_model result = DatasetService.create_empty_dataset( @@ -264,7 +264,7 @@ class TestDatasetServiceCreateDataset: # Act with ( - patch("services.dataset_service.ModelManager") as mock_model_manager, + patch("services.dataset_service.ModelManager.for_tenant") as mock_model_manager, patch("services.dataset_service.DatasetService.check_reranking_model_setting") as mock_check_reranking, ): mock_model_manager.return_value.get_default_model_instance.return_value = embedding_model @@ -297,7 +297,7 @@ class TestDatasetServiceCreateDataset: # Act with ( - patch("services.dataset_service.ModelManager") as mock_model_manager, + patch("services.dataset_service.ModelManager.for_tenant") as mock_model_manager, patch("services.dataset_service.DatasetService.check_embedding_model_setting") as mock_check_embedding, ): mock_model_manager.return_value.get_model_instance.return_value = embedding_model diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py index 2899d5b8a5..a814466e14 100644 --- a/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py @@ -2,10 +2,10 @@ from unittest.mock import Mock, patch from uuid import uuid4 import pytest +from graphon.model_runtime.entities.model_entities import ModelType from sqlalchemy.orm import Session from core.rag.index_processor.constant.index_type import IndexTechniqueType -from dify_graph.model_runtime.entities.model_entities import ModelType from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.dataset import Dataset, 
ExternalKnowledgeBindings from models.enums import DataSourceType @@ -363,7 +363,7 @@ class TestDatasetServiceUpdateDataset: with ( patch("services.dataset_service.current_user", user), - patch("services.dataset_service.ModelManager") as mock_model_manager, + patch("services.dataset_service.ModelManager.for_tenant") as mock_model_manager, patch( "services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding" ) as mock_get_binding, @@ -458,7 +458,7 @@ class TestDatasetServiceUpdateDataset: with ( patch("services.dataset_service.current_user", user), - patch("services.dataset_service.ModelManager") as mock_model_manager, + patch("services.dataset_service.ModelManager.for_tenant") as mock_model_manager, patch( "services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding" ) as mock_get_binding, @@ -544,7 +544,7 @@ class TestDatasetServiceUpdateDataset: with ( patch("services.dataset_service.current_user", user), - patch("services.dataset_service.ModelManager") as mock_model_manager, + patch("services.dataset_service.ModelManager.for_tenant") as mock_model_manager, ): mock_model_manager.return_value.get_model_instance.side_effect = Exception("No Embedding Model available") diff --git a/api/tests/test_containers_integration_tests/services/test_delete_archived_workflow_run.py b/api/tests/test_containers_integration_tests/services/test_delete_archived_workflow_run.py index 376a89d1ce..c8f04e9215 100644 --- a/api/tests/test_containers_integration_tests/services/test_delete_archived_workflow_run.py +++ b/api/tests/test_containers_integration_tests/services/test_delete_archived_workflow_run.py @@ -5,9 +5,9 @@ Testcontainers integration tests for archived workflow run deletion service. 
from datetime import UTC, datetime, timedelta from uuid import uuid4 +from graphon.enums import WorkflowExecutionStatus from sqlalchemy import select -from dify_graph.enums import WorkflowExecutionStatus from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom from models.workflow import WorkflowArchiveLog, WorkflowRun from services.retention.workflow_run.delete_archived_workflow_run import ArchivedWorkflowRunDeletion diff --git a/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py index 70d05792ce..c46b8fba0b 100644 --- a/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py +++ b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test.py @@ -3,14 +3,14 @@ import uuid from unittest.mock import MagicMock import pytest +from graphon.enums import BuiltinNodeTypes +from graphon.nodes.human_input.entities import HumanInputNodeData -from dify_graph.enums import BuiltinNodeTypes -from dify_graph.nodes.human_input.entities import ( +from core.workflow.human_input_compat import ( EmailDeliveryConfig, EmailDeliveryMethod, EmailRecipients, ExternalRecipient, - HumanInputNodeData, ) from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.model import App, AppMode @@ -54,7 +54,7 @@ def _create_app_with_draft_workflow(session, *, delivery_method_id: uuid.UUID) - enabled=True, config=EmailDeliveryConfig( recipients=EmailRecipients( - whole_workspace=False, + include_bound_group=False, items=[ExternalRecipient(email="recipient@example.com")], ), subject="Test {{recipient_email}}", diff --git a/api/tests/unit_tests/services/test_human_input_delivery_test_service.py b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test_service.py similarity index 81% rename from 
api/tests/unit_tests/services/test_human_input_delivery_test_service.py rename to api/tests/test_containers_integration_tests/services/test_human_input_delivery_test_service.py index 3b1c1fcf17..0f252515f7 100644 --- a/api/tests/unit_tests/services/test_human_input_delivery_test_service.py +++ b/api/tests/test_containers_integration_tests/services/test_human_input_delivery_test_service.py @@ -1,18 +1,22 @@ +from __future__ import annotations + from types import SimpleNamespace from unittest.mock import MagicMock, patch +from uuid import uuid4 import pytest +from graphon.runtime import VariablePool from sqlalchemy.engine import Engine from configs import dify_config -from dify_graph.nodes.human_input.entities import ( +from core.workflow.human_input_compat import ( EmailDeliveryConfig, EmailDeliveryMethod, EmailRecipients, ExternalRecipient, MemberRecipient, ) -from dify_graph.runtime import VariablePool +from models.account import Account, TenantAccountJoin from services import human_input_delivery_test_service as service_module from services.human_input_delivery_test_service import ( DeliveryTestContext, @@ -28,15 +32,12 @@ from services.human_input_delivery_test_service import ( ) -@pytest.fixture -def mock_db(monkeypatch): - mock_db = MagicMock() - monkeypatch.setattr(service_module, "db", mock_db) - return mock_db - - def _make_valid_email_config(): - return EmailDeliveryConfig(recipients=EmailRecipients(whole_workspace=False, items=[]), subject="Subj", body="Body") + return EmailDeliveryConfig( + recipients=EmailRecipients(include_bound_group=False, items=[]), + subject="Subj", + body="Body", + ) def test_build_form_link(): @@ -87,7 +88,7 @@ class TestDeliveryTestRegistry: with pytest.raises(DeliveryTestUnsupportedError, match="Delivery method does not support test send."): registry.dispatch(context=context, method=method) - def test_default(self, mock_db): + def test_default(self, flask_app_with_containers, db_session_with_containers): registry = 
DeliveryTestRegistry.default() assert len(registry._handlers) == 1 assert isinstance(registry._handlers[0], EmailDeliveryTestHandler) @@ -246,62 +247,70 @@ class TestEmailDeliveryTestHandler: _, kwargs = mock_mail_send.call_args assert kwargs["subject"] == "Notice BCC:test@example.com" - def test_resolve_recipients(self): + def test_resolve_recipients_external(self): handler = EmailDeliveryTestHandler(session_factory=MagicMock()) - - # Test Case 1: External Recipient method = EmailDeliveryMethod( config=EmailDeliveryConfig( - recipients=EmailRecipients(items=[ExternalRecipient(email="ext@example.com")], whole_workspace=False), + recipients=EmailRecipients( + items=[ExternalRecipient(email="ext@example.com")], include_bound_group=False + ), subject="", body="", ) ) assert handler._resolve_recipients(tenant_id="t1", method=method) == ["ext@example.com"] - # Test Case 2: Member Recipient + def test_resolve_recipients_member(self, flask_app_with_containers, db_session_with_containers): + tenant_id = str(uuid4()) + account = Account(name="Test User", email="member@example.com") + db_session_with_containers.add(account) + db_session_with_containers.commit() + + join = TenantAccountJoin(tenant_id=tenant_id, account_id=account.id) + db_session_with_containers.add(join) + db_session_with_containers.commit() + + from extensions.ext_database import db + + handler = EmailDeliveryTestHandler(session_factory=db.engine) method = EmailDeliveryMethod( config=EmailDeliveryConfig( - recipients=EmailRecipients(items=[MemberRecipient(user_id="u1")], whole_workspace=False), + recipients=EmailRecipients(items=[MemberRecipient(reference_id=account.id)], include_bound_group=False), subject="", body="", ) ) - handler._query_workspace_member_emails = MagicMock(return_value={"u1": "u1@example.com"}) - assert handler._resolve_recipients(tenant_id="t1", method=method) == ["u1@example.com"] + assert handler._resolve_recipients(tenant_id=tenant_id, method=method) == ["member@example.com"] - # 
Test Case 3: Whole Workspace + def test_resolve_recipients_whole_workspace(self, flask_app_with_containers, db_session_with_containers): + tenant_id = str(uuid4()) + account1 = Account(name="User 1", email=f"u1-{uuid4()}@example.com") + account2 = Account(name="User 2", email=f"u2-{uuid4()}@example.com") + db_session_with_containers.add_all([account1, account2]) + db_session_with_containers.commit() + + for acc in [account1, account2]: + join = TenantAccountJoin(tenant_id=tenant_id, account_id=acc.id) + db_session_with_containers.add(join) + db_session_with_containers.commit() + + from extensions.ext_database import db + + handler = EmailDeliveryTestHandler(session_factory=db.engine) method = EmailDeliveryMethod( - config=EmailDeliveryConfig(recipients=EmailRecipients(items=[], whole_workspace=True), subject="", body="") + config=EmailDeliveryConfig( + recipients=EmailRecipients(items=[], include_bound_group=True), + subject="", + body="", + ) ) - handler._query_workspace_member_emails = MagicMock( - return_value={"u1": "u1@example.com", "u2": "u2@example.com"} - ) - recipients = handler._resolve_recipients(tenant_id="t1", method=method) - assert set(recipients) == {"u1@example.com", "u2@example.com"} + recipients = handler._resolve_recipients(tenant_id=tenant_id, method=method) + assert set(recipients) == {account1.email, account2.email} - def test_query_workspace_member_emails(self): - mock_session = MagicMock() - mock_session_factory = MagicMock(return_value=mock_session) - mock_session.__enter__.return_value = mock_session - - handler = EmailDeliveryTestHandler(session_factory=mock_session_factory) - - # Empty user_ids + def test_query_workspace_member_emails_empty_ids(self): + handler = EmailDeliveryTestHandler(session_factory=MagicMock()) assert handler._query_workspace_member_emails(tenant_id="t1", user_ids=[]) == {} - # user_ids is None (all) - mock_execute = MagicMock() - mock_session.execute.return_value = mock_execute - mock_execute.all.return_value = 
[("u1", "u1@example.com")] - - result = handler._query_workspace_member_emails(tenant_id="t1", user_ids=None) - assert result == {"u1": "u1@example.com"} - - # user_ids with values - result = handler._query_workspace_member_emails(tenant_id="t1", user_ids=["u1"]) - assert result == {"u1": "u1@example.com"} - def test_build_substitutions(self): context = DeliveryTestContext( tenant_id="t1", @@ -323,7 +332,6 @@ class TestEmailDeliveryTestHandler: assert subs["form_token"] == "token123" assert "form/token123" in subs["form_link"] - # Without matching recipient subs_no_match = EmailDeliveryTestHandler._build_substitutions( context=context, recipient_email="other@example.com" ) diff --git a/api/tests/test_containers_integration_tests/services/test_message_service.py b/api/tests/test_containers_integration_tests/services/test_message_service.py index 85dc04b162..bdf6d9b951 100644 --- a/api/tests/test_containers_integration_tests/services/test_message_service.py +++ b/api/tests/test_containers_integration_tests/services/test_message_service.py @@ -25,7 +25,7 @@ class TestMessageService: """Mock setup for external service dependencies.""" with ( patch("services.account_service.FeatureService") as mock_account_feature_service, - patch("services.message_service.ModelManager") as mock_model_manager, + patch("services.message_service.ModelManager.for_tenant") as mock_model_manager, patch("services.message_service.WorkflowService") as mock_workflow_service, patch("services.message_service.AdvancedChatAppConfigManager") as mock_app_config_manager, patch("services.message_service.LLMGenerator") as mock_llm_generator, diff --git a/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py b/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py index 57bbc73b50..2340dd2a03 100644 --- a/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py +++ 
b/api/tests/test_containers_integration_tests/services/test_messages_clean_service.py @@ -1,14 +1,16 @@ +from __future__ import annotations + import datetime import json import uuid from decimal import Decimal -from unittest.mock import patch +from unittest.mock import MagicMock, patch import pytest from faker import Faker +from graphon.file import FileType from sqlalchemy.orm import Session -from dify_graph.file.enums import FileType from enums.cloud_plan import CloudPlan from extensions.ext_redis import redis_client from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole @@ -1169,3 +1171,66 @@ class TestMessagesCleanServiceIntegration: # Verify all messages were deleted assert db_session_with_containers.query(Message).where(Message.id.in_(msg_ids)).count() == 0 + + def test_from_time_range_validation(self): + """Test that from_time_range raises ValueError for invalid inputs.""" + policy = MagicMock(spec=BillingDisabledPolicy) + now = datetime.datetime.now() + + with pytest.raises(ValueError, match="start_from .* must be less than end_before"): + MessagesCleanService.from_time_range(policy, now, now) + + with pytest.raises(ValueError, match="batch_size .* must be greater than 0"): + MessagesCleanService.from_time_range(policy, now - datetime.timedelta(days=1), now, batch_size=0) + + def test_from_time_range_success(self): + """Test that from_time_range creates a service with correct parameters.""" + policy = MagicMock(spec=BillingDisabledPolicy) + start = datetime.datetime(2024, 1, 1) + end = datetime.datetime(2024, 2, 1) + + service = MessagesCleanService.from_time_range(policy, start, end) + assert service._start_from == start + assert service._end_before == end + + def test_from_days_validation(self): + """Test that from_days raises ValueError for invalid inputs.""" + policy = MagicMock(spec=BillingDisabledPolicy) + + with pytest.raises(ValueError, match="days .* must be greater than or equal to 0"): + 
MessagesCleanService.from_days(policy, days=-1) + + with pytest.raises(ValueError, match="batch_size .* must be greater than 0"): + MessagesCleanService.from_days(policy, days=30, batch_size=0) + + def test_from_days_success(self): + """Test that from_days creates a service with correct parameters.""" + policy = MagicMock(spec=BillingDisabledPolicy) + + with patch("services.retention.conversation.messages_clean_service.naive_utc_now") as mock_now: + fixed_now = datetime.datetime(2024, 6, 1) + mock_now.return_value = fixed_now + + service = MessagesCleanService.from_days(policy, days=10) + assert service._start_from is None + assert service._end_before == fixed_now - datetime.timedelta(days=10) + + def test_batch_delete_message_relations_empty(self, db_session_with_containers: Session): + """Test that batch_delete_message_relations with empty list does nothing.""" + # Get execute call count before + MessagesCleanService._batch_delete_message_relations(db_session_with_containers, []) + # No exception means success — empty list is a no-op + + def test_run_calls_clean_messages(self): + """Test that run() delegates to _clean_messages_by_time_range.""" + policy = MagicMock(spec=BillingDisabledPolicy) + service = MessagesCleanService( + policy=policy, + end_before=datetime.datetime.now(), + batch_size=10, + ) + with patch.object(service, "_clean_messages_by_time_range") as mock_clean: + mock_clean.return_value = {"total_deleted": 5} + result = service.run() + assert result == {"total_deleted": 5} + mock_clean.assert_called_once() diff --git a/api/tests/test_containers_integration_tests/services/test_metadata_partial_update.py b/api/tests/test_containers_integration_tests/services/test_metadata_partial_update.py new file mode 100644 index 0000000000..b55a19eaa9 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_metadata_partial_update.py @@ -0,0 +1,182 @@ +from __future__ import annotations + +from unittest.mock import Mock, patch +from uuid 
import uuid4 + +import pytest +from sqlalchemy import select + +from models.dataset import Dataset, DatasetMetadataBinding, Document +from models.enums import DataSourceType, DocumentCreatedFrom +from services.entities.knowledge_entities.knowledge_entities import ( + DocumentMetadataOperation, + MetadataDetail, + MetadataOperationData, +) +from services.metadata_service import MetadataService + + +def _create_dataset(db_session, *, tenant_id: str, built_in_field_enabled: bool = False) -> Dataset: + dataset = Dataset( + tenant_id=tenant_id, + name=f"dataset-{uuid4()}", + data_source_type=DataSourceType.UPLOAD_FILE, + created_by=str(uuid4()), + ) + dataset.id = str(uuid4()) + dataset.built_in_field_enabled = built_in_field_enabled + db_session.add(dataset) + db_session.commit() + return dataset + + +def _create_document(db_session, *, dataset_id: str, tenant_id: str, doc_metadata: dict | None = None) -> Document: + document = Document( + tenant_id=tenant_id, + dataset_id=dataset_id, + position=1, + data_source_type=DataSourceType.UPLOAD_FILE, + data_source_info="{}", + batch=f"batch-{uuid4()}", + name=f"doc-{uuid4()}", + created_from=DocumentCreatedFrom.WEB, + created_by=str(uuid4()), + ) + document.id = str(uuid4()) + document.doc_metadata = doc_metadata + db_session.add(document) + db_session.commit() + return document + + +class TestMetadataPartialUpdate: + @pytest.fixture + def tenant_id(self) -> str: + return str(uuid4()) + + @pytest.fixture + def user_id(self) -> str: + return str(uuid4()) + + @pytest.fixture + def mock_current_account(self, user_id, tenant_id): + account = Mock(id=user_id, current_tenant_id=tenant_id) + with patch("services.metadata_service.current_account_with_tenant", return_value=(account, tenant_id)): + yield account + + def test_partial_update_merges_metadata( + self, flask_app_with_containers, db_session_with_containers, tenant_id, mock_current_account + ): + dataset = _create_dataset(db_session_with_containers, tenant_id=tenant_id) + 
document = _create_document( + db_session_with_containers, + dataset_id=dataset.id, + tenant_id=tenant_id, + doc_metadata={"existing_key": "existing_value"}, + ) + + meta_id = str(uuid4()) + operation = DocumentMetadataOperation( + document_id=document.id, + metadata_list=[MetadataDetail(id=meta_id, name="new_key", value="new_value")], + partial_update=True, + ) + metadata_args = MetadataOperationData(operation_data=[operation]) + + MetadataService.update_documents_metadata(dataset, metadata_args) + db_session_with_containers.expire_all() + + updated_doc = db_session_with_containers.get(Document, document.id) + assert updated_doc is not None + assert updated_doc.doc_metadata["existing_key"] == "existing_value" + assert updated_doc.doc_metadata["new_key"] == "new_value" + + def test_full_update_replaces_metadata( + self, flask_app_with_containers, db_session_with_containers, tenant_id, mock_current_account + ): + dataset = _create_dataset(db_session_with_containers, tenant_id=tenant_id) + document = _create_document( + db_session_with_containers, + dataset_id=dataset.id, + tenant_id=tenant_id, + doc_metadata={"existing_key": "existing_value"}, + ) + + meta_id = str(uuid4()) + operation = DocumentMetadataOperation( + document_id=document.id, + metadata_list=[MetadataDetail(id=meta_id, name="new_key", value="new_value")], + partial_update=False, + ) + metadata_args = MetadataOperationData(operation_data=[operation]) + + MetadataService.update_documents_metadata(dataset, metadata_args) + db_session_with_containers.expire_all() + + updated_doc = db_session_with_containers.get(Document, document.id) + assert updated_doc is not None + assert updated_doc.doc_metadata == {"new_key": "new_value"} + assert "existing_key" not in updated_doc.doc_metadata + + def test_partial_update_skips_existing_binding( + self, flask_app_with_containers, db_session_with_containers, tenant_id, user_id, mock_current_account + ): + dataset = _create_dataset(db_session_with_containers, 
tenant_id=tenant_id) + document = _create_document( + db_session_with_containers, + dataset_id=dataset.id, + tenant_id=tenant_id, + doc_metadata={"existing_key": "existing_value"}, + ) + + meta_id = str(uuid4()) + existing_binding = DatasetMetadataBinding( + tenant_id=tenant_id, + dataset_id=dataset.id, + document_id=document.id, + metadata_id=meta_id, + created_by=user_id, + ) + db_session_with_containers.add(existing_binding) + db_session_with_containers.commit() + + operation = DocumentMetadataOperation( + document_id=document.id, + metadata_list=[MetadataDetail(id=meta_id, name="existing_key", value="existing_value")], + partial_update=True, + ) + metadata_args = MetadataOperationData(operation_data=[operation]) + + MetadataService.update_documents_metadata(dataset, metadata_args) + db_session_with_containers.expire_all() + + bindings = db_session_with_containers.scalars( + select(DatasetMetadataBinding).where( + DatasetMetadataBinding.document_id == document.id, + DatasetMetadataBinding.metadata_id == meta_id, + ) + ).all() + assert len(bindings) == 1 + + def test_rollback_called_on_commit_failure( + self, flask_app_with_containers, db_session_with_containers, tenant_id, mock_current_account + ): + dataset = _create_dataset(db_session_with_containers, tenant_id=tenant_id) + document = _create_document( + db_session_with_containers, + dataset_id=dataset.id, + tenant_id=tenant_id, + doc_metadata={"existing_key": "existing_value"}, + ) + + meta_id = str(uuid4()) + operation = DocumentMetadataOperation( + document_id=document.id, + metadata_list=[MetadataDetail(id=meta_id, name="key", value="value")], + partial_update=True, + ) + metadata_args = MetadataOperationData(operation_data=[operation]) + + with patch("services.metadata_service.db.session.commit", side_effect=RuntimeError("database connection lost")): + with pytest.raises(RuntimeError, match="database connection lost"): + MetadataService.update_documents_metadata(dataset, metadata_args) diff --git 
a/api/tests/test_containers_integration_tests/services/test_model_load_balancing_service.py b/api/tests/test_containers_integration_tests/services/test_model_load_balancing_service.py index 989df42499..ca6e7afeab 100644 --- a/api/tests/test_containers_integration_tests/services/test_model_load_balancing_service.py +++ b/api/tests/test_containers_integration_tests/services/test_model_load_balancing_service.py @@ -18,11 +18,10 @@ class TestModelLoadBalancingService: def mock_external_service_dependencies(self): """Mock setup for external service dependencies.""" with ( - patch("services.model_load_balancing_service.ProviderManager", autospec=True) as mock_provider_manager, - patch("services.model_load_balancing_service.LBModelManager", autospec=True) as mock_lb_model_manager, patch( - "services.model_load_balancing_service.ModelProviderFactory", autospec=True - ) as mock_model_provider_factory, + "services.model_load_balancing_service.create_plugin_provider_manager", autospec=True + ) as mock_provider_manager, + patch("services.model_load_balancing_service.LBModelManager", autospec=True) as mock_lb_model_manager, patch("services.model_load_balancing_service.encrypter", autospec=True) as mock_encrypter, ): # Setup default mock returns @@ -46,9 +45,6 @@ class TestModelLoadBalancingService: # Mock LBModelManager mock_lb_model_manager.get_config_in_cooldown_and_ttl.return_value = (False, 0) - # Mock ModelProviderFactory - mock_model_provider_factory_instance = mock_model_provider_factory.return_value - # Mock credential schemas mock_credential_schema = MagicMock() mock_credential_schema.credential_form_schemas = [] @@ -61,7 +57,6 @@ class TestModelLoadBalancingService: yield { "provider_manager": mock_provider_manager, "lb_model_manager": mock_lb_model_manager, - "model_provider_factory": mock_model_provider_factory, "encrypter": mock_encrypter, "provider_config": mock_provider_config, "provider_model_setting": mock_provider_model_setting, diff --git 
a/api/tests/test_containers_integration_tests/services/test_model_provider_service.py b/api/tests/test_containers_integration_tests/services/test_model_provider_service.py index 6afc5aa43c..ba926bf675 100644 --- a/api/tests/test_containers_integration_tests/services/test_model_provider_service.py +++ b/api/tests/test_containers_integration_tests/services/test_model_provider_service.py @@ -2,10 +2,10 @@ from unittest.mock import MagicMock, patch import pytest from faker import Faker +from graphon.model_runtime.entities.model_entities import FetchFrom, ModelType from sqlalchemy.orm import Session from core.entities.model_entities import ModelStatus -from dify_graph.model_runtime.entities.model_entities import FetchFrom, ModelType from models import Account, Tenant, TenantAccountJoin, TenantAccountRole from models.provider import Provider, ProviderModel, ProviderModelSetting, ProviderType from services.model_provider_service import ModelProviderService @@ -18,8 +18,12 @@ class TestModelProviderService: def mock_external_service_dependencies(self): """Mock setup for external service dependencies.""" with ( - patch("services.model_provider_service.ProviderManager", autospec=True) as mock_provider_manager, - patch("services.model_provider_service.ModelProviderFactory", autospec=True) as mock_model_provider_factory, + patch( + "services.model_provider_service.create_plugin_provider_manager", autospec=True + ) as mock_provider_manager, + patch( + "services.model_provider_service.create_plugin_model_provider_factory", autospec=True + ) as mock_model_provider_factory, ): # Setup default mock returns mock_provider_manager.return_value.get_configurations.return_value = MagicMock() @@ -401,9 +405,10 @@ class TestModelProviderService: mock_provider_manager = mock_external_service_dependencies["provider_manager"].return_value # Create mock models + from graphon.model_runtime.entities.common_entities import I18nObject + from graphon.model_runtime.entities.provider_entities import 
ProviderEntity + from core.entities.model_entities import ModelWithProviderEntity, SimpleModelProviderEntity - from dify_graph.model_runtime.entities.common_entities import I18nObject - from dify_graph.model_runtime.entities.provider_entities import ProviderEntity # Create real model objects instead of mocks provider_entity_1 = SimpleModelProviderEntity( @@ -639,8 +644,9 @@ class TestModelProviderService: mock_provider_manager = mock_external_service_dependencies["provider_manager"].return_value # Create mock default model response + from graphon.model_runtime.entities.common_entities import I18nObject + from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity - from dify_graph.model_runtime.entities.common_entities import I18nObject mock_default_model = DefaultModelEntity( model="gpt-3.5-turbo", diff --git a/api/tests/test_containers_integration_tests/services/test_restore_archived_workflow_run.py b/api/tests/test_containers_integration_tests/services/test_restore_archived_workflow_run.py index ba4310e22e..7036524918 100644 --- a/api/tests/test_containers_integration_tests/services/test_restore_archived_workflow_run.py +++ b/api/tests/test_containers_integration_tests/services/test_restore_archived_workflow_run.py @@ -2,17 +2,43 @@ Testcontainers integration tests for workflow run restore functionality. 
""" +from __future__ import annotations + +from datetime import datetime from uuid import uuid4 from sqlalchemy import select -from models.workflow import WorkflowPause +from models.workflow import WorkflowPause, WorkflowRun from services.retention.workflow_run.restore_archived_workflow_run import WorkflowRunRestore class TestWorkflowRunRestore: """Tests for the WorkflowRunRestore class.""" + def test_restore_initialization(self): + """Restore service should respect dry_run flag.""" + restore = WorkflowRunRestore(dry_run=True) + + assert restore.dry_run is True + + def test_convert_datetime_fields(self): + """ISO datetime strings should be converted to datetime objects.""" + record = { + "id": "test-id", + "created_at": "2024-01-01T12:00:00", + "finished_at": "2024-01-01T12:05:00", + "name": "test", + } + + restore = WorkflowRunRestore() + result = restore._convert_datetime_fields(record, WorkflowRun) + + assert isinstance(result["created_at"], datetime) + assert result["created_at"].year == 2024 + assert result["created_at"].month == 1 + assert result["name"] == "test" + def test_restore_table_records_returns_rowcount(self, db_session_with_containers): """Restore should return inserted rowcount.""" restore = WorkflowRunRestore() diff --git a/api/tests/test_containers_integration_tests/services/test_saved_message_service.py b/api/tests/test_containers_integration_tests/services/test_saved_message_service.py index d256c0d90b..70aa813142 100644 --- a/api/tests/test_containers_integration_tests/services/test_saved_message_service.py +++ b/api/tests/test_containers_integration_tests/services/test_saved_message_service.py @@ -20,7 +20,7 @@ class TestSavedMessageService: """Mock setup for external service dependencies.""" with ( patch("services.account_service.FeatureService") as mock_account_feature_service, - patch("services.app_service.ModelManager") as mock_model_manager, + patch("services.app_service.ModelManager.for_tenant") as mock_model_manager, 
patch("services.saved_message_service.MessageService") as mock_message_service, ): # Setup default mock returns diff --git a/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py b/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py index 6b95954480..f2307fbd7d 100644 --- a/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py +++ b/api/tests/test_containers_integration_tests/services/test_web_conversation_service.py @@ -25,7 +25,7 @@ class TestWebConversationService: with ( patch("services.app_service.FeatureService") as mock_feature_service, patch("services.app_service.EnterpriseService") as mock_enterprise_service, - patch("services.app_service.ModelManager") as mock_model_manager, + patch("services.app_service.ModelManager.for_tenant") as mock_model_manager, patch("services.account_service.FeatureService") as mock_account_feature_service, ): # Setup default mock returns for app service diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py index 880143013e..749c6fff5b 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_app_service.py @@ -8,9 +8,9 @@ from unittest.mock import patch import pytest from faker import Faker +from graphon.enums import WorkflowExecutionStatus from sqlalchemy.orm import Session -from dify_graph.entities.workflow_execution import WorkflowExecutionStatus from models import EndUser, Workflow, WorkflowAppLog, WorkflowArchiveLog, WorkflowRun from models.enums import AppTriggerType, CreatorUserRole, WorkflowRunTriggeredFrom from models.workflow import WorkflowAppLogCreatedFrom @@ -31,7 +31,7 @@ class TestWorkflowAppService: with ( patch("services.app_service.FeatureService") as mock_feature_service, 
patch("services.app_service.EnterpriseService") as mock_enterprise_service, - patch("services.app_service.ModelManager") as mock_model_manager, + patch("services.app_service.ModelManager.for_tenant") as mock_model_manager, patch("services.account_service.FeatureService") as mock_account_feature_service, ): # Setup default mock returns for app service diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py index 572cf72fa0..0c281c8c33 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_draft_variable_service.py @@ -1,9 +1,9 @@ import pytest from faker import Faker +from graphon.variables.segments import StringSegment from sqlalchemy.orm import Session -from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID -from dify_graph.variables.segments import StringSegment +from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID from models import App, Workflow from models.enums import DraftVariableType from models.workflow import WorkflowDraftVariable @@ -482,7 +482,7 @@ class TestWorkflowDraftVariableService: fake = Faker() app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake) workflow = self._create_test_workflow(db_session_with_containers, app, fake=fake) - from dify_graph.variables.variables import StringVariable + from graphon.variables.variables import StringVariable conv_var = StringVariable( id=fake.uuid4(), @@ -734,7 +734,7 @@ class TestWorkflowDraftVariableService: fake = Faker() app = self._create_test_app(db_session_with_containers, mock_external_service_dependencies, fake=fake) workflow = self._create_test_workflow(db_session_with_containers, app, fake=fake) - from 
dify_graph.variables.variables import StringVariable + from graphon.variables.variables import StringVariable conv_var1 = StringVariable( id=fake.uuid4(), diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_run_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_run_service.py index 731770e01a..d02a078281 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_run_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_run_service.py @@ -27,7 +27,7 @@ class TestWorkflowRunService: with ( patch("services.app_service.FeatureService") as mock_feature_service, patch("services.app_service.EnterpriseService") as mock_enterprise_service, - patch("services.app_service.ModelManager") as mock_model_manager, + patch("services.app_service.ModelManager.for_tenant") as mock_model_manager, patch("services.account_service.FeatureService") as mock_account_feature_service, ): # Setup default mock returns for app service diff --git a/api/tests/test_containers_integration_tests/services/test_workflow_service.py b/api/tests/test_containers_integration_tests/services/test_workflow_service.py index a5fe052206..b5ce8a53de 100644 --- a/api/tests/test_containers_integration_tests/services/test_workflow_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workflow_service.py @@ -555,6 +555,124 @@ class TestWorkflowService: assert len(result_workflows) == 2 assert all(wf.marked_name for wf in result_workflows) + def test_get_all_published_workflow_no_workflow_id(self, db_session_with_containers: Session): + """Test that an app with no workflow_id returns empty results.""" + # Arrange + fake = Faker() + app = self._create_test_app(db_session_with_containers, fake) + app.workflow_id = None + db_session_with_containers.commit() + + workflow_service = WorkflowService() + + # Act + result_workflows, has_more = workflow_service.get_all_published_workflow( + 
session=db_session_with_containers, app_model=app, page=1, limit=10, user_id=None + ) + + # Assert + assert result_workflows == [] + assert has_more is False + + def test_get_all_published_workflow_basic(self, db_session_with_containers: Session): + """Test basic retrieval of published workflows.""" + # Arrange + fake = Faker() + account = self._create_test_account(db_session_with_containers, fake) + app = self._create_test_app(db_session_with_containers, fake) + + workflow1 = self._create_test_workflow(db_session_with_containers, app, account, fake) + workflow1.version = "2024.01.01.001" + workflow2 = self._create_test_workflow(db_session_with_containers, app, account, fake) + workflow2.version = "2024.01.02.001" + + app.workflow_id = workflow1.id + db_session_with_containers.commit() + + workflow_service = WorkflowService() + + # Act + result_workflows, has_more = workflow_service.get_all_published_workflow( + session=db_session_with_containers, app_model=app, page=1, limit=10, user_id=None + ) + + # Assert + assert len(result_workflows) == 2 + assert has_more is False + + def test_get_all_published_workflow_combined_filters(self, db_session_with_containers: Session): + """Test combined user_id and named_only filters.""" + # Arrange + fake = Faker() + account1 = self._create_test_account(db_session_with_containers, fake) + account2 = self._create_test_account(db_session_with_containers, fake) + app = self._create_test_app(db_session_with_containers, fake) + + # account1 named + wf1 = self._create_test_workflow(db_session_with_containers, app, account1, fake) + wf1.version = "2024.01.01.001" + wf1.marked_name = "Named by user1" + wf1.created_by = account1.id + + # account1 unnamed + wf2 = self._create_test_workflow(db_session_with_containers, app, account1, fake) + wf2.version = "2024.01.02.001" + wf2.marked_name = "" + wf2.created_by = account1.id + + # account2 named + wf3 = self._create_test_workflow(db_session_with_containers, app, account2, fake) + 
wf3.version = "2024.01.03.001" + wf3.marked_name = "Named by user2" + wf3.created_by = account2.id + + app.workflow_id = wf1.id + db_session_with_containers.commit() + + workflow_service = WorkflowService() + + # Act - Filter by account1 + named_only + result_workflows, has_more = workflow_service.get_all_published_workflow( + session=db_session_with_containers, + app_model=app, + page=1, + limit=10, + user_id=account1.id, + named_only=True, + ) + + # Assert - Only wf1 matches (account1 + named) + assert len(result_workflows) == 1 + assert result_workflows[0].marked_name == "Named by user1" + assert result_workflows[0].created_by == account1.id + + def test_get_all_published_workflow_empty_result(self, db_session_with_containers: Session): + """Test that querying with no matching workflows returns empty.""" + # Arrange + fake = Faker() + account = self._create_test_account(db_session_with_containers, fake) + app = self._create_test_app(db_session_with_containers, fake) + + # Create a draft workflow (no version set = draft) + workflow = self._create_test_workflow(db_session_with_containers, app, account, fake) + app.workflow_id = workflow.id + db_session_with_containers.commit() + + workflow_service = WorkflowService() + + # Act - Filter by a user that has no workflows + result_workflows, has_more = workflow_service.get_all_published_workflow( + session=db_session_with_containers, + app_model=app, + page=1, + limit=10, + user_id="00000000-0000-0000-0000-000000000000", + ) + + # Assert + assert result_workflows == [] + assert has_more is False + def test_sync_draft_workflow_create_new(self, db_session_with_containers: Session): """ Test creating a new draft workflow through sync operation. 
@@ -1503,10 +1621,10 @@ class TestWorkflowService: import uuid from datetime import datetime - from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus - from dify_graph.graph_events import NodeRunSucceededEvent - from dify_graph.node_events import NodeRunResult - from dify_graph.nodes.base.node import Node + from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus + from graphon.graph_events import NodeRunSucceededEvent + from graphon.node_events import NodeRunResult + from graphon.nodes.base.node import Node # Create mock node mock_node = MagicMock(spec=Node) @@ -1548,12 +1666,12 @@ class TestWorkflowService: # Assert assert result is not None assert result.node_id == node_id - from dify_graph.enums import BuiltinNodeTypes + from graphon.enums import BuiltinNodeTypes assert result.node_type == BuiltinNodeTypes.START # Should match the mock node type assert result.title == "Test Node" # Import the enum for comparison - from dify_graph.enums import WorkflowNodeExecutionStatus + from graphon.enums import WorkflowNodeExecutionStatus assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED assert result.inputs is not None @@ -1578,10 +1696,10 @@ class TestWorkflowService: import uuid from datetime import datetime - from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus - from dify_graph.graph_events import NodeRunFailedEvent - from dify_graph.node_events import NodeRunResult - from dify_graph.nodes.base.node import Node + from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus + from graphon.graph_events import NodeRunFailedEvent + from graphon.node_events import NodeRunResult + from graphon.nodes.base.node import Node # Create mock node mock_node = MagicMock(spec=Node) @@ -1623,7 +1741,7 @@ class TestWorkflowService: assert result is not None assert result.node_id == node_id # Import the enum for comparison - from dify_graph.enums import WorkflowNodeExecutionStatus + from graphon.enums import 
WorkflowNodeExecutionStatus assert result.status == WorkflowNodeExecutionStatus.FAILED assert result.error is not None @@ -1647,10 +1765,10 @@ class TestWorkflowService: import uuid from datetime import datetime - from dify_graph.enums import BuiltinNodeTypes, ErrorStrategy, WorkflowNodeExecutionStatus - from dify_graph.graph_events import NodeRunFailedEvent - from dify_graph.node_events import NodeRunResult - from dify_graph.nodes.base.node import Node + from graphon.enums import BuiltinNodeTypes, ErrorStrategy, WorkflowNodeExecutionStatus + from graphon.graph_events import NodeRunFailedEvent + from graphon.node_events import NodeRunResult + from graphon.nodes.base.node import Node # Create mock node with continue_on_error mock_node = MagicMock(spec=Node) @@ -1693,7 +1811,7 @@ class TestWorkflowService: assert result is not None assert result.node_id == node_id # Import the enum for comparison - from dify_graph.enums import WorkflowNodeExecutionStatus + from graphon.enums import WorkflowNodeExecutionStatus assert result.status == WorkflowNodeExecutionStatus.EXCEPTION # Should be EXCEPTION, not FAILED assert result.outputs is not None diff --git a/api/tests/test_containers_integration_tests/services/test_workspace_service.py b/api/tests/test_containers_integration_tests/services/test_workspace_service.py index 92dec24c7d..4e89d906f1 100644 --- a/api/tests/test_containers_integration_tests/services/test_workspace_service.py +++ b/api/tests/test_containers_integration_tests/services/test_workspace_service.py @@ -1,4 +1,6 @@ -from unittest.mock import patch +from __future__ import annotations + +from unittest.mock import MagicMock, patch import pytest from faker import Faker @@ -534,3 +536,283 @@ class TestWorkspaceService: # Verify database state db_session_with_containers.refresh(tenant) assert tenant.id is not None + + def test_get_tenant_info_should_raise_assertion_when_join_missing( + self, db_session_with_containers: Session, mock_external_service_dependencies + 
): + """TenantAccountJoin must exist; missing join should raise AssertionError.""" + fake = Faker() + account = Account(email=fake.email(), name=fake.name(), interface_language="en-US", status="active") + db_session_with_containers.add(account) + db_session_with_containers.commit() + + tenant = Tenant(name=fake.company(), status="normal", plan="basic") + db_session_with_containers.add(tenant) + db_session_with_containers.commit() + + # No TenantAccountJoin created + with patch("services.workspace_service.current_user", account): + with pytest.raises(AssertionError, match="TenantAccountJoin not found"): + WorkspaceService.get_tenant_info(tenant) + + def test_get_tenant_info_should_set_replace_webapp_logo_to_none_when_flag_absent( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """replace_webapp_logo should be None when custom_config_dict does not have the key.""" + import json + + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + tenant.custom_config = json.dumps({}) + db_session_with_containers.commit() + + mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = True + mock_external_service_dependencies["tenant_service"].has_roles.return_value = True + + with patch("services.workspace_service.current_user", account): + result = WorkspaceService.get_tenant_info(tenant) + + assert result is not None + assert result["custom_config"]["replace_webapp_logo"] is None + + def test_get_tenant_info_should_use_files_url_for_logo_url( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """The logo URL should use dify_config.FILES_URL as the base.""" + import json + + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + tenant.custom_config = json.dumps({"replace_webapp_logo": 
True}) + db_session_with_containers.commit() + + custom_base = "https://cdn.mycompany.io" + mock_external_service_dependencies["dify_config"].FILES_URL = custom_base + mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = True + mock_external_service_dependencies["tenant_service"].has_roles.return_value = True + + with patch("services.workspace_service.current_user", account): + result = WorkspaceService.get_tenant_info(tenant) + + assert result is not None + assert result["custom_config"]["replace_webapp_logo"].startswith(custom_base) + + def test_get_tenant_info_should_not_include_cloud_fields_in_self_hosted( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """next_credit_reset_date and trial_credits should NOT appear in SELF_HOSTED mode.""" + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + mock_external_service_dependencies["dify_config"].EDITION = "SELF_HOSTED" + mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = False + mock_external_service_dependencies["tenant_service"].has_roles.return_value = False + + with patch("services.workspace_service.current_user", account): + result = WorkspaceService.get_tenant_info(tenant) + + assert result is not None + assert "next_credit_reset_date" not in result + assert "trial_credits" not in result + assert "trial_credits_used" not in result + + def test_get_tenant_info_cloud_credit_reset_date( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """next_credit_reset_date should be present in CLOUD edition.""" + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + mock_external_service_dependencies["dify_config"].EDITION = "CLOUD" + feature = 
mock_external_service_dependencies["feature_service"].get_features.return_value + feature.can_replace_logo = False + feature.next_credit_reset_date = "2025-02-01" + feature.billing.subscription.plan = "professional" + mock_external_service_dependencies["tenant_service"].has_roles.return_value = False + + with ( + patch("services.workspace_service.current_user", account), + patch("services.credit_pool_service.CreditPoolService.get_pool", return_value=None), + ): + result = WorkspaceService.get_tenant_info(tenant) + + assert result is not None + assert result["next_credit_reset_date"] == "2025-02-01" + + def test_get_tenant_info_cloud_paid_pool_not_full( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """trial_credits come from paid pool when plan is not sandbox and pool is not full.""" + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + mock_external_service_dependencies["dify_config"].EDITION = "CLOUD" + feature = mock_external_service_dependencies["feature_service"].get_features.return_value + feature.can_replace_logo = False + feature.next_credit_reset_date = "2025-02-01" + feature.billing.subscription.plan = "professional" + mock_external_service_dependencies["tenant_service"].has_roles.return_value = False + + paid_pool = MagicMock(quota_limit=1000, quota_used=200) + + with ( + patch("services.workspace_service.current_user", account), + patch("services.credit_pool_service.CreditPoolService.get_pool", return_value=paid_pool), + ): + result = WorkspaceService.get_tenant_info(tenant) + + assert result is not None + assert result["trial_credits"] == 1000 + assert result["trial_credits_used"] == 200 + + def test_get_tenant_info_cloud_paid_pool_unlimited( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """quota_limit == -1 means unlimited; service should use paid pool.""" + fake = Faker() + 
account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + mock_external_service_dependencies["dify_config"].EDITION = "CLOUD" + feature = mock_external_service_dependencies["feature_service"].get_features.return_value + feature.can_replace_logo = False + feature.next_credit_reset_date = "2025-02-01" + feature.billing.subscription.plan = "professional" + mock_external_service_dependencies["tenant_service"].has_roles.return_value = False + + paid_pool = MagicMock(quota_limit=-1, quota_used=999) + + with ( + patch("services.workspace_service.current_user", account), + patch("services.credit_pool_service.CreditPoolService.get_pool", side_effect=[paid_pool, None]), + ): + result = WorkspaceService.get_tenant_info(tenant) + + assert result is not None + assert result["trial_credits"] == -1 + assert result["trial_credits_used"] == 999 + + def test_get_tenant_info_cloud_fall_back_to_trial_when_paid_full( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """When paid pool is exhausted, switch to trial pool.""" + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + mock_external_service_dependencies["dify_config"].EDITION = "CLOUD" + feature = mock_external_service_dependencies["feature_service"].get_features.return_value + feature.can_replace_logo = False + feature.next_credit_reset_date = "2025-02-01" + feature.billing.subscription.plan = "professional" + mock_external_service_dependencies["tenant_service"].has_roles.return_value = False + + paid_pool = MagicMock(quota_limit=500, quota_used=500) + trial_pool = MagicMock(quota_limit=100, quota_used=10) + + with ( + patch("services.workspace_service.current_user", account), + patch("services.credit_pool_service.CreditPoolService.get_pool", side_effect=[paid_pool, trial_pool]), + ): + result = 
WorkspaceService.get_tenant_info(tenant) + + assert result is not None + assert result["trial_credits"] == 100 + assert result["trial_credits_used"] == 10 + + def test_get_tenant_info_cloud_fall_back_to_trial_when_paid_none( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """When paid_pool is None, fall back to trial pool.""" + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + mock_external_service_dependencies["dify_config"].EDITION = "CLOUD" + feature = mock_external_service_dependencies["feature_service"].get_features.return_value + feature.can_replace_logo = False + feature.next_credit_reset_date = "2025-02-01" + feature.billing.subscription.plan = "professional" + mock_external_service_dependencies["tenant_service"].has_roles.return_value = False + + trial_pool = MagicMock(quota_limit=50, quota_used=5) + + with ( + patch("services.workspace_service.current_user", account), + patch("services.credit_pool_service.CreditPoolService.get_pool", side_effect=[None, trial_pool]), + ): + result = WorkspaceService.get_tenant_info(tenant) + + assert result is not None + assert result["trial_credits"] == 50 + assert result["trial_credits_used"] == 5 + + def test_get_tenant_info_cloud_sandbox_uses_trial_pool( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """When plan is SANDBOX, skip paid pool and use trial pool.""" + from enums.cloud_plan import CloudPlan + + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + mock_external_service_dependencies["dify_config"].EDITION = "CLOUD" + feature = mock_external_service_dependencies["feature_service"].get_features.return_value + feature.can_replace_logo = False + feature.next_credit_reset_date = "2025-02-01" + feature.billing.subscription.plan = CloudPlan.SANDBOX + 
mock_external_service_dependencies["tenant_service"].has_roles.return_value = False + + paid_pool = MagicMock(quota_limit=1000, quota_used=0) + trial_pool = MagicMock(quota_limit=200, quota_used=20) + + with ( + patch("services.workspace_service.current_user", account), + patch("services.credit_pool_service.CreditPoolService.get_pool", side_effect=[paid_pool, trial_pool]), + ): + result = WorkspaceService.get_tenant_info(tenant) + + assert result is not None + assert result["trial_credits"] == 200 + assert result["trial_credits_used"] == 20 + + def test_get_tenant_info_cloud_both_pools_none( + self, db_session_with_containers: Session, mock_external_service_dependencies + ): + """When both paid and trial pools are absent, trial_credits should not be set.""" + fake = Faker() + account, tenant = self._create_test_account_and_tenant( + db_session_with_containers, mock_external_service_dependencies + ) + + mock_external_service_dependencies["dify_config"].EDITION = "CLOUD" + feature = mock_external_service_dependencies["feature_service"].get_features.return_value + feature.can_replace_logo = False + feature.next_credit_reset_date = "2025-02-01" + feature.billing.subscription.plan = "professional" + mock_external_service_dependencies["tenant_service"].has_roles.return_value = False + + with ( + patch("services.workspace_service.current_user", account), + patch("services.credit_pool_service.CreditPoolService.get_pool", side_effect=[None, None]), + ): + result = WorkspaceService.get_tenant_info(tenant) + + assert result is not None + assert "trial_credits" not in result + assert "trial_credits_used" not in result diff --git a/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py b/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py index e3c0749494..21a1975879 100644 --- a/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py +++ 
b/api/tests/test_containers_integration_tests/services/tools/test_workflow_tools_manage_service.py @@ -25,7 +25,7 @@ class TestWorkflowToolManageService: with ( patch("services.app_service.FeatureService") as mock_feature_service, patch("services.app_service.EnterpriseService") as mock_enterprise_service, - patch("services.app_service.ModelManager") as mock_model_manager, + patch("services.app_service.ModelManager.for_tenant") as mock_model_manager, patch("services.account_service.FeatureService") as mock_account_feature_service, patch( "services.tools.workflow_tools_manage_service.WorkflowToolProviderController" diff --git a/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py index c3fe6a2950..ce2fd2eeb1 100644 --- a/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py +++ b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_converter.py @@ -1,11 +1,19 @@ +from __future__ import annotations + import json -from unittest.mock import patch +from unittest.mock import MagicMock, patch import pytest from faker import Faker +from graphon.model_runtime.entities.llm_entities import LLMMode +from graphon.model_runtime.entities.message_entities import PromptMessageRole +from graphon.variables.input_entities import VariableEntity, VariableEntityType from sqlalchemy.orm import Session from core.app.app_config.entities import ( + AdvancedChatMessageEntity, + AdvancedChatPromptTemplateEntity, + AdvancedCompletionPromptTemplateEntity, DatasetEntity, DatasetRetrieveConfigEntity, ExternalDataVariableEntity, @@ -13,10 +21,8 @@ from core.app.app_config.entities import ( PromptTemplateEntity, ) from core.prompt.utils.prompt_template_parser import PromptTemplateParser -from dify_graph.model_runtime.entities.llm_entities import LLMMode -from dify_graph.variables.input_entities import VariableEntity, 
VariableEntityType from models import Account, Tenant -from models.api_based_extension import APIBasedExtension +from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint from models.model import App, AppMode, AppModelConfig from models.workflow import Workflow from services.workflow.workflow_converter import WorkflowConverter @@ -548,3 +554,198 @@ class TestWorkflowConverter: # Verify single retrieval config is None for multiple strategy assert node["data"]["single_retrieval_config"] is None + + +@pytest.fixture +def default_variables(): + return [ + VariableEntity(variable="text_input", label="text-input", type=VariableEntityType.TEXT_INPUT), + VariableEntity(variable="paragraph", label="paragraph", type=VariableEntityType.PARAGRAPH), + VariableEntity(variable="select", label="select", type=VariableEntityType.SELECT), + ] + + +class TestConvertToHttpRequestNodeVariants: + """Tests for chatbot vs workflow differences in HTTP request node conversion.""" + + @staticmethod + def _setup(app_mode, default_variables): + app_model = App( + tenant_id="tenant_id", + mode=app_mode, + name="test", + icon_type="emoji", + icon="🤖", + icon_background="#FFFFFF", + ) + + ext = APIBasedExtension(tenant_id="tenant_id", name="api-1", api_key="enc", api_endpoint="https://dify.ai") + ext.id = "ext_id" + + converter = WorkflowConverter() + converter._get_api_based_extension = MagicMock(return_value=ext) + + from core.helper import encrypter + + encrypter.decrypt_token = MagicMock(return_value="api_key") + + ext_vars = [ + ExternalDataVariableEntity( + variable="external_variable", type="api", config={"api_based_extension_id": "ext_id"} + ) + ] + nodes, _ = converter._convert_to_http_request_node( + app_model=app_model, + variables=default_variables, + external_data_variables=ext_vars, + ) + return nodes + + def test_chatbot_query_uses_sys_query(self, default_variables): + nodes = self._setup(AppMode.CHAT, default_variables) + + body = 
json.loads(nodes[0]["data"]["body"]["data"]) + assert body["params"]["query"] == "{{#sys.query#}}" + assert body["point"] == APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY + assert nodes[1]["data"]["type"] == "code" + + def test_workflow_query_is_empty(self, default_variables): + nodes = self._setup(AppMode.WORKFLOW, default_variables) + + body = json.loads(nodes[0]["data"]["body"]["data"]) + assert body["params"]["query"] == "" + + +class TestConvertToKnowledgeRetrievalNodeVariants: + """Tests for chatbot vs workflow differences in knowledge retrieval node.""" + + @staticmethod + def _dataset_config(query_variable=None): + return DatasetEntity( + dataset_ids=["ds1", "ds2"], + retrieve_config=DatasetRetrieveConfigEntity( + query_variable=query_variable, + retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE, + top_k=5, + score_threshold=0.8, + reranking_model={"reranking_provider_name": "cohere", "reranking_model_name": "rerank-english-v2.0"}, + reranking_enabled=True, + ), + ) + + @staticmethod + def _model_config(): + return ModelConfigEntity(provider="openai", model="gpt-4", mode="chat", parameters={}, stop=[]) + + def test_chatbot_uses_sys_query(self): + node = WorkflowConverter()._convert_to_knowledge_retrieval_node( + new_app_mode=AppMode.ADVANCED_CHAT, + dataset_config=self._dataset_config(), + model_config=self._model_config(), + ) + assert node["data"]["query_variable_selector"] == ["sys", "query"] + + def test_workflow_uses_start_variable(self): + node = WorkflowConverter()._convert_to_knowledge_retrieval_node( + new_app_mode=AppMode.WORKFLOW, + dataset_config=self._dataset_config(query_variable="query"), + model_config=self._model_config(), + ) + assert node["data"]["query_variable_selector"] == ["start", "query"] + + +class TestConvertToLlmNode: + """Tests for LLM node conversion across model modes and prompt types.""" + + @staticmethod + def _model_config(model, mode): + return ModelConfigEntity( + provider="openai", + 
model=model, + mode=mode.value, + parameters={}, + stop=[], + ) + + @staticmethod + def _graph(default_variables): + start = WorkflowConverter()._convert_to_start_node(default_variables) + return {"nodes": [start], "edges": []} + + def test_simple_chat_model(self, default_variables): + prompt = PromptTemplateEntity( + prompt_type=PromptTemplateEntity.PromptType.SIMPLE, + simple_prompt_template="You are helpful {{text_input}}, {{paragraph}}, {{select}}.", + ) + node = WorkflowConverter()._convert_to_llm_node( + original_app_mode=AppMode.CHAT, + new_app_mode=AppMode.ADVANCED_CHAT, + model_config=self._model_config("gpt-4", LLMMode.CHAT), + graph=self._graph(default_variables), + prompt_template=prompt, + ) + assert node["data"]["type"] == "llm" + assert node["data"]["model"]["mode"] == LLMMode.CHAT.value + assert node["data"]["context"]["enabled"] is False + expected = "You are helpful {{#start.text_input#}}, {{#start.paragraph#}}, {{#start.select#}}.\n" + assert node["data"]["prompt_template"][0]["text"] == expected + + def test_simple_completion_model(self, default_variables): + prompt = PromptTemplateEntity( + prompt_type=PromptTemplateEntity.PromptType.SIMPLE, + simple_prompt_template="You are helpful {{text_input}}, {{paragraph}}, {{select}}.", + ) + node = WorkflowConverter()._convert_to_llm_node( + original_app_mode=AppMode.CHAT, + new_app_mode=AppMode.ADVANCED_CHAT, + model_config=self._model_config("gpt-3.5-turbo-instruct", LLMMode.COMPLETION), + graph=self._graph(default_variables), + prompt_template=prompt, + ) + assert node["data"]["model"]["mode"] == LLMMode.COMPLETION.value + expected = "You are helpful {{#start.text_input#}}, {{#start.paragraph#}}, {{#start.select#}}.\n" + assert node["data"]["prompt_template"]["text"] == expected + + def test_advanced_chat_model(self, default_variables): + prompt = PromptTemplateEntity( + prompt_type=PromptTemplateEntity.PromptType.ADVANCED, + advanced_chat_prompt_template=AdvancedChatPromptTemplateEntity( + 
messages=[ + AdvancedChatMessageEntity( + text="You are helpful named {{name}}.\n\nContext:\n{{#context#}}", + role=PromptMessageRole.SYSTEM, + ), + AdvancedChatMessageEntity(text="Hi.", role=PromptMessageRole.USER), + AdvancedChatMessageEntity(text="Hello!", role=PromptMessageRole.ASSISTANT), + ] + ), + ) + node = WorkflowConverter()._convert_to_llm_node( + original_app_mode=AppMode.CHAT, + new_app_mode=AppMode.ADVANCED_CHAT, + model_config=self._model_config("gpt-4", LLMMode.CHAT), + graph=self._graph(default_variables), + prompt_template=prompt, + ) + assert isinstance(node["data"]["prompt_template"], list) + assert len(node["data"]["prompt_template"]) == 3 + + def test_advanced_completion_model(self, default_variables): + prompt = PromptTemplateEntity( + prompt_type=PromptTemplateEntity.PromptType.ADVANCED, + advanced_completion_prompt_template=AdvancedCompletionPromptTemplateEntity( + prompt="You are helpful named {{name}}.\n\nContext:\n{{#context#}}\n\nHuman: hi\nAssistant: ", + role_prefix=AdvancedCompletionPromptTemplateEntity.RolePrefixEntity( + user="Human", assistant="Assistant" + ), + ), + ) + node = WorkflowConverter()._convert_to_llm_node( + original_app_mode=AppMode.CHAT, + new_app_mode=AppMode.ADVANCED_CHAT, + model_config=self._model_config("gpt-3.5-turbo-instruct", LLMMode.COMPLETION), + graph=self._graph(default_variables), + prompt_template=prompt, + ) + assert isinstance(node["data"]["prompt_template"], dict) + assert "text" in node["data"]["prompt_template"] diff --git a/api/tests/test_containers_integration_tests/services/workflow/test_workflow_node_execution_service_repository.py b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_node_execution_service_repository.py index af9e8d0b2c..7c43bf676b 100644 --- a/api/tests/test_containers_integration_tests/services/workflow/test_workflow_node_execution_service_repository.py +++ 
b/api/tests/test_containers_integration_tests/services/workflow/test_workflow_node_execution_service_repository.py @@ -1,10 +1,10 @@ from datetime import datetime, timedelta from uuid import uuid4 +from graphon.enums import WorkflowNodeExecutionStatus from sqlalchemy import Engine, select from sqlalchemy.orm import Session, sessionmaker -from dify_graph.enums import WorkflowNodeExecutionStatus from libs.datetime_utils import naive_utc_now from models.enums import CreatorUserRole from models.workflow import WorkflowNodeExecutionModel diff --git a/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py b/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py index d2e343ef52..f9ae33b32f 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_batch_create_segment_to_index_task.py @@ -54,7 +54,10 @@ class TestBatchCreateSegmentToIndexTask: """Mock setup for external service dependencies.""" with ( patch("tasks.batch_create_segment_to_index_task.storage", autospec=True) as mock_storage, - patch("tasks.batch_create_segment_to_index_task.ModelManager", autospec=True) as mock_model_manager, + patch( + "tasks.batch_create_segment_to_index_task.ModelManager.for_tenant", + autospec=True, + ) as mock_model_manager, patch("tasks.batch_create_segment_to_index_task.VectorService", autospec=True) as mock_vector_service, ): # Setup default mock returns diff --git a/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py b/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py index 0876a39f82..a16f3ff773 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_mail_human_input_delivery_task.py @@ -3,22 +3,22 @@ from datetime 
import UTC, datetime from unittest.mock import patch import pytest +from graphon.enums import WorkflowExecutionStatus +from graphon.nodes.human_input.entities import HumanInputNodeData +from graphon.runtime import GraphRuntimeState, VariablePool from configs import dify_config from core.app.app_config.entities import WorkflowUIBasedAppConfig from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext from core.repositories.human_input_repository import FormCreateParams, HumanInputFormRepositoryImpl -from dify_graph.enums import WorkflowExecutionStatus -from dify_graph.nodes.human_input.entities import ( +from core.workflow.human_input_compat import ( EmailDeliveryConfig, EmailDeliveryMethod, EmailRecipients, ExternalRecipient, - HumanInputNodeData, MemberRecipient, ) -from dify_graph.runtime import GraphRuntimeState, VariablePool from extensions.ext_storage import storage from models.account import Account, AccountStatus, Tenant, TenantAccountJoin, TenantAccountRole from models.enums import CreatorUserRole, WorkflowRunTriggeredFrom @@ -79,9 +79,9 @@ def _build_form(db_session_with_containers, tenant, account, *, app_id: str, wor delivery_method = EmailDeliveryMethod( config=EmailDeliveryConfig( recipients=EmailRecipients( - whole_workspace=False, + include_bound_group=False, items=[ - MemberRecipient(user_id=account.id), + MemberRecipient(reference_id=account.id), ExternalRecipient(email="external@example.com"), ], ), @@ -96,9 +96,8 @@ def _build_form(db_session_with_containers, tenant, account, *, app_id: str, wor delivery_methods=[delivery_method], ) - repo = HumanInputFormRepositoryImpl(tenant_id=tenant.id) + repo = HumanInputFormRepositoryImpl(tenant_id=tenant.id, app_id=app_id) params = FormCreateParams( - app_id=app_id, workflow_execution_id=workflow_execution_id, node_id="node-1", form_config=node_data, diff --git 
a/api/tests/test_containers_integration_tests/tasks/test_remove_app_and_related_data_task.py b/api/tests/test_containers_integration_tests/tasks/test_remove_app_and_related_data_task.py index 5bded4d670..96cf9cebf5 100644 --- a/api/tests/test_containers_integration_tests/tasks/test_remove_app_and_related_data_task.py +++ b/api/tests/test_containers_integration_tests/tasks/test_remove_app_and_related_data_task.py @@ -2,10 +2,10 @@ import uuid from unittest.mock import ANY, call, patch import pytest +from graphon.variables.segments import StringSegment +from graphon.variables.types import SegmentType from core.db.session_factory import session_factory -from dify_graph.variables.segments import StringSegment -from dify_graph.variables.types import SegmentType from extensions.storage.storage_type import StorageType from libs.datetime_utils import naive_utc_now from models import Tenant diff --git a/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py b/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py index ca76fa0a4b..159ab51304 100644 --- a/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py +++ b/api/tests/test_containers_integration_tests/test_workflow_pause_integration.py @@ -24,11 +24,11 @@ from dataclasses import dataclass from datetime import timedelta import pytest +from graphon.entities import WorkflowExecution +from graphon.enums import WorkflowExecutionStatus from sqlalchemy import delete, select from sqlalchemy.orm import Session, selectinload, sessionmaker -from dify_graph.entities import WorkflowExecution -from dify_graph.enums import WorkflowExecutionStatus from extensions.ext_storage import storage from libs.datetime_utils import naive_utc_now from models import Account diff --git a/api/tests/test_containers_integration_tests/trigger/test_trigger_e2e.py b/api/tests/test_containers_integration_tests/trigger/test_trigger_e2e.py index 4ea8d8c1c7..7539bae685 100644 --- 
a/api/tests/test_containers_integration_tests/trigger/test_trigger_e2e.py +++ b/api/tests/test_containers_integration_tests/trigger/test_trigger_e2e.py @@ -10,6 +10,7 @@ from typing import Any import pytest from flask import Flask, Response from flask.testing import FlaskClient +from graphon.enums import BuiltinNodeTypes from sqlalchemy.orm import Session from configs import dify_config @@ -23,7 +24,6 @@ from core.trigger.debug import event_selectors from core.trigger.debug.event_bus import TriggerDebugEventBus from core.trigger.debug.event_selectors import PluginTriggerDebugEventPoller, WebhookTriggerDebugEventPoller from core.trigger.debug.events import PluginTriggerDebugEvent, build_plugin_pool_key -from dify_graph.enums import BuiltinNodeTypes from libs.datetime_utils import naive_utc_now from models.account import Account, Tenant from models.enums import AppTriggerStatus, AppTriggerType, CreatorUserRole, WorkflowTriggerStatus diff --git a/api/tests/unit_tests/conftest.py b/api/tests/unit_tests/conftest.py index 3f75fd2851..55873b06a8 100644 --- a/api/tests/unit_tests/conftest.py +++ b/api/tests/unit_tests/conftest.py @@ -123,27 +123,26 @@ def _configure_session_factory(_unit_test_engine): def setup_mock_tenant_account_query(mock_db, mock_tenant, mock_account): """ - Helper to set up the mock DB query chain for tenant/account authentication. + Helper to set up the mock DB execute chain for tenant/account authentication. - This configures the mock to return (tenant, account) for the join query used - by validate_app_token and validate_dataset_token decorators. + This configures the mock to return (tenant, account) for the + db.session.execute(select(...).join().join().where()).one_or_none() + query used by validate_app_token decorator. 
Args: mock_db: The mocked db object mock_tenant: Mock tenant object to return mock_account: Mock account object to return """ - query = mock_db.session.query.return_value - join_chain = query.join.return_value.join.return_value - where_chain = join_chain.where.return_value - where_chain.one_or_none.return_value = (mock_tenant, mock_account) + mock_db.session.execute.return_value.one_or_none.return_value = (mock_tenant, mock_account) def setup_mock_dataset_tenant_query(mock_db, mock_tenant, mock_ta): """ - Helper to set up the mock DB query chain for dataset tenant authentication. + Helper to set up the mock DB execute chain for dataset tenant authentication. - This configures the mock to return (tenant, tenant_account) for the where chain + This configures the mock to return (tenant, tenant_account) for the + db.session.execute(select(...).where().where().where().where()).one_or_none() query used by validate_dataset_token decorator. Args: @@ -151,6 +150,4 @@ def setup_mock_dataset_tenant_query(mock_db, mock_tenant, mock_ta): mock_tenant: Mock tenant object to return mock_ta: Mock tenant account object to return """ - query = mock_db.session.query.return_value - where_chain = query.where.return_value.where.return_value.where.return_value.where.return_value - where_chain.one_or_none.return_value = (mock_tenant, mock_ta) + mock_db.session.execute.return_value.one_or_none.return_value = (mock_tenant, mock_ta) diff --git a/api/tests/unit_tests/controllers/console/app/test_audio.py b/api/tests/unit_tests/controllers/console/app/test_audio.py index 021e9a0784..c52bc02420 100644 --- a/api/tests/unit_tests/controllers/console/app/test_audio.py +++ b/api/tests/unit_tests/controllers/console/app/test_audio.py @@ -4,6 +4,7 @@ import io from types import SimpleNamespace import pytest +from graphon.model_runtime.errors.invoke import InvokeError from werkzeug.datastructures import FileStorage from werkzeug.exceptions import InternalServerError @@ -20,7 +21,6 @@ from 
controllers.console.app.error import ( UnsupportedAudioTypeError, ) from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from dify_graph.model_runtime.errors.invoke import InvokeError from services.audio_service import AudioService from services.errors.app_model_config import AppModelConfigBrokenError from services.errors.audio import ( diff --git a/api/tests/unit_tests/controllers/console/app/test_workflow.py b/api/tests/unit_tests/controllers/console/app/test_workflow.py index 0e22db9f9b..3607636880 100644 --- a/api/tests/unit_tests/controllers/console/app/test_workflow.py +++ b/api/tests/unit_tests/controllers/console/app/test_workflow.py @@ -5,12 +5,11 @@ from types import SimpleNamespace from unittest.mock import Mock import pytest +from graphon.file import File, FileTransferMethod, FileType from werkzeug.exceptions import HTTPException, NotFound from controllers.console.app import workflow as workflow_module from controllers.console.app.error import DraftWorkflowNotExist, DraftWorkflowNotSync -from dify_graph.file.enums import FileTransferMethod, FileType -from dify_graph.file.models import File def _unwrap(func): diff --git a/api/tests/unit_tests/controllers/console/app/test_workflow_pause_details_api.py b/api/tests/unit_tests/controllers/console/app/test_workflow_pause_details_api.py index 83601dc1b9..e11102acb1 100644 --- a/api/tests/unit_tests/controllers/console/app/test_workflow_pause_details_api.py +++ b/api/tests/unit_tests/controllers/console/app/test_workflow_pause_details_api.py @@ -6,14 +6,14 @@ from unittest.mock import Mock import pytest from flask import Flask +from graphon.entities.pause_reason import HumanInputRequired +from graphon.enums import WorkflowExecutionStatus +from graphon.nodes.human_input.entities import FormInput, UserAction +from graphon.nodes.human_input.enums import FormInputType from controllers.console import wraps as console_wraps from controllers.console.app import 
workflow_run as workflow_run_module from controllers.web.error import NotFoundError -from dify_graph.entities.pause_reason import HumanInputRequired -from dify_graph.enums import WorkflowExecutionStatus -from dify_graph.nodes.human_input.entities import FormInput, UserAction -from dify_graph.nodes.human_input.enums import FormInputType from libs import login as login_lib from models.account import Account, AccountStatus, TenantAccountRole from models.workflow import WorkflowRun @@ -67,7 +67,6 @@ def test_pause_details_returns_backstage_input_url(app: Flask, monkeypatch: pyte actions=[UserAction(id="approve", title="Approve")], node_id="node-1", node_title="Ask Name", - form_token="backstage-token", ) pause_entity = _PauseEntity(paused_at=datetime(2024, 1, 1, 12, 0, 0), reasons=[reason]) @@ -78,6 +77,11 @@ def test_pause_details_returns_backstage_input_url(app: Flask, monkeypatch: pyte "create_api_workflow_run_repository", lambda *_, **__: repo, ) + monkeypatch.setattr( + workflow_run_module, + "_load_form_tokens_by_form_id", + lambda _form_ids: {"form-1": "backstage-token"}, + ) with app.test_request_context("/console/api/workflow/run-1/pause-details", method="GET"): response, status = workflow_run_module.ConsoleWorkflowPauseDetailsApi().get(workflow_run_id="run-1") diff --git a/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py b/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py index f34702a257..740da1f1df 100644 --- a/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py +++ b/api/tests/unit_tests/controllers/console/app/workflow_draft_variables_test.py @@ -5,6 +5,7 @@ from unittest.mock import MagicMock, patch import pytest from flask_restx import marshal +from graphon.variables.types import SegmentType from controllers.console.app.workflow_draft_variable import ( _WORKFLOW_DRAFT_VARIABLE_FIELDS, @@ -13,8 +14,7 @@ from controllers.console.app.workflow_draft_variable import ( 
_WORKFLOW_DRAFT_VARIABLE_WITHOUT_VALUE_FIELDS, _serialize_full_content, ) -from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID -from dify_graph.variables.types import SegmentType +from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID from factories.variable_factory import build_segment from libs.datetime_utils import naive_utc_now from libs.uuid_utils import uuidv7 @@ -310,13 +310,11 @@ def test_workflow_node_variables_fields(): def test_workflow_file_variable_with_signed_url(): """Test that File type variables include signed URLs in API responses.""" - from dify_graph.file.enums import FileTransferMethod, FileType - from dify_graph.file.models import File + from graphon.file import File, FileTransferMethod, FileType # Create a File object with LOCAL_FILE transfer method (which generates signed URLs) test_file = File( id="test_file_id", - tenant_id="test_tenant_id", type=FileType.IMAGE, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="test_upload_file_id", @@ -368,13 +366,11 @@ def test_workflow_file_variable_with_signed_url(): def test_workflow_file_variable_remote_url(): """Test that File type variables with REMOTE_URL transfer method return the remote URL.""" - from dify_graph.file.enums import FileTransferMethod, FileType - from dify_graph.file.models import File + from graphon.file import File, FileTransferMethod, FileType # Create a File object with REMOTE_URL transfer method test_file = File( id="test_file_id", - tenant_id="test_tenant_id", type=FileType.IMAGE, transfer_method=FileTransferMethod.REMOTE_URL, remote_url="https://example.com/test.jpg", diff --git a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_datasource_auth.py b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_datasource_auth.py index 9014edc39e..9c9f8da87c 100644 --- a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_datasource_auth.py +++ 
b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_datasource_auth.py @@ -1,6 +1,7 @@ from unittest.mock import MagicMock, patch import pytest +from graphon.model_runtime.errors.validate import CredentialsValidateFailedError from werkzeug.exceptions import Forbidden, NotFound from controllers.console import console_ns @@ -17,7 +18,6 @@ from controllers.console.datasets.rag_pipeline.datasource_auth import ( DatasourceUpdateProviderNameApi, ) from core.plugin.impl.oauth import OAuthHandler -from dify_graph.model_runtime.errors.validate import CredentialsValidateFailedError from services.datasource_provider_service import DatasourceProviderService from services.plugin.oauth_service import OAuthProxyService diff --git a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_draft_variable.py b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_draft_variable.py index b4c0903f63..6ef8ccfdbd 100644 --- a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_draft_variable.py +++ b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_draft_variable.py @@ -2,6 +2,7 @@ from unittest.mock import MagicMock, patch import pytest from flask import Response +from graphon.variables.types import SegmentType from controllers.console import console_ns from controllers.console.app.error import DraftWorkflowNotExist @@ -14,8 +15,7 @@ from controllers.console.datasets.rag_pipeline.rag_pipeline_draft_variable impor RagPipelineVariableResetApi, ) from controllers.web.error import InvalidArgumentError, NotFoundError -from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID -from dify_graph.variables.types import SegmentType +from core.workflow.variable_prefixes import SYSTEM_VARIABLE_NODE_ID from models.account import Account diff --git a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py 
b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py index 472d133349..a3c0592d76 100644 --- a/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py +++ b/api/tests/unit_tests/controllers/console/datasets/rag_pipeline/test_rag_pipeline_workflow.py @@ -2,7 +2,7 @@ from datetime import datetime from unittest.mock import MagicMock, patch import pytest -from werkzeug.exceptions import Forbidden, HTTPException, NotFound +from werkzeug.exceptions import BadRequest, Forbidden, HTTPException, NotFound import services from controllers.console import console_ns @@ -26,6 +26,7 @@ from controllers.console.datasets.rag_pipeline.rag_pipeline_workflow import ( RagPipelineWorkflowLastRunApi, ) from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError +from libs.datetime_utils import naive_utc_now from services.errors.app import IsDraftWorkflowError, WorkflowHashNotEqualError, WorkflowNotFoundError from services.errors.llm import InvokeRateLimitError @@ -372,7 +373,7 @@ class TestPublishedPipelineApis: workflow = MagicMock( id="w1", - created_at=datetime.utcnow(), + created_at=naive_utc_now(), ) session = MagicMock() @@ -691,6 +692,57 @@ class TestRagPipelineByIdApi: result, status = method(api, pipeline, "w1") assert status == 400 + def test_delete_success(self, app): + api = RagPipelineByIdApi() + method = unwrap(api.delete) + + pipeline = MagicMock(tenant_id="t1", workflow_id="active-workflow", id="pipeline-1") + + workflow_service = MagicMock() + + session = MagicMock() + session_ctx = MagicMock() + session_ctx.__enter__.return_value = session + session_ctx.__exit__.return_value = None + + fake_db = MagicMock() + fake_db.engine = MagicMock() + + with ( + app.test_request_context("/", method="DELETE"), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.db", + fake_db, + ), + patch( + 
"controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.Session", + return_value=session_ctx, + ), + patch( + "controllers.console.datasets.rag_pipeline.rag_pipeline_workflow.WorkflowService", + return_value=workflow_service, + ), + ): + result = method(api, pipeline, "old-workflow") + + workflow_service.delete_workflow.assert_called_once_with( + session=session, + workflow_id="old-workflow", + tenant_id="t1", + ) + session.commit.assert_called_once() + assert result == (None, 204) + + def test_delete_active_workflow_rejected(self, app): + api = RagPipelineByIdApi() + method = unwrap(api.delete) + + pipeline = MagicMock(tenant_id="t1", workflow_id="active-workflow", id="pipeline-1") + + with app.test_request_context("/", method="DELETE"): + with pytest.raises(BadRequest, match="currently in use by pipeline"): + method(api, pipeline, "active-workflow") + class TestRagPipelineWorkflowLastRunApi: def test_last_run_success(self, app): diff --git a/api/tests/unit_tests/controllers/console/datasets/test_datasets.py b/api/tests/unit_tests/controllers/console/datasets/test_datasets.py index ff565f19fd..8555900f4e 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_datasets.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_datasets.py @@ -417,7 +417,7 @@ class TestDatasetApiGet: "check_dataset_permission", return_value=None, ), - patch("controllers.console.datasets.datasets.ProviderManager") as provider_manager_mock, + patch("controllers.console.datasets.datasets.create_plugin_provider_manager") as provider_manager_mock, ): # embedding models exist → embedding_available stays True provider_manager_mock.return_value.get_configurations.return_value.get_models.return_value = [] @@ -521,7 +521,7 @@ class TestDatasetApiGet: "check_dataset_permission", return_value=None, ), - patch("controllers.console.datasets.datasets.ProviderManager") as provider_manager_mock, + patch("controllers.console.datasets.datasets.create_plugin_provider_manager") 
as provider_manager_mock, ): # embedding model NOT configured provider_manager_mock.return_value.get_configurations.return_value.get_models.return_value = [] @@ -580,7 +580,7 @@ class TestDatasetApiGet: "get_dataset_partial_member_list", return_value=partial_members, ), - patch("controllers.console.datasets.datasets.ProviderManager") as provider_manager_mock, + patch("controllers.console.datasets.datasets.create_plugin_provider_manager") as provider_manager_mock, ): provider_manager_mock.return_value.get_configurations.return_value.get_models.return_value = [] diff --git a/api/tests/unit_tests/controllers/console/datasets/test_datasets_segments.py b/api/tests/unit_tests/controllers/console/datasets/test_datasets_segments.py index 306a772fd1..693b06e95b 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_datasets_segments.py +++ b/api/tests/unit_tests/controllers/console/datasets/test_datasets_segments.py @@ -1,4 +1,3 @@ -from datetime import datetime from types import SimpleNamespace from unittest.mock import MagicMock, patch @@ -25,6 +24,7 @@ from controllers.console.datasets.error import ( ) from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.rag.index_processor.constant.index_type import IndexStructureType +from libs.datetime_utils import naive_utc_now from models.dataset import ChildChunk, DocumentSegment from models.model import UploadFile @@ -54,8 +54,8 @@ def _segment(): disabled_by=None, status="normal", created_by="u1", - created_at=datetime.utcnow(), - updated_at=datetime.utcnow(), + created_at=naive_utc_now(), + updated_at=naive_utc_now(), updated_by="u1", indexing_at=None, completed_at=None, diff --git a/api/tests/unit_tests/controllers/console/datasets/test_hit_testing_base.py b/api/tests/unit_tests/controllers/console/datasets/test_hit_testing_base.py index e7ae37ae45..710c9be684 100644 --- a/api/tests/unit_tests/controllers/console/datasets/test_hit_testing_base.py +++ 
b/api/tests/unit_tests/controllers/console/datasets/test_hit_testing_base.py @@ -1,6 +1,7 @@ from unittest.mock import MagicMock, patch import pytest +from graphon.model_runtime.errors.invoke import InvokeError from werkzeug.exceptions import Forbidden, InternalServerError, NotFound import services @@ -20,7 +21,6 @@ from core.errors.error import ( ProviderTokenNotInitError, QuotaExceededError, ) -from dify_graph.model_runtime.errors.invoke import InvokeError from models.account import Account from services.dataset_service import DatasetService from services.hit_testing_service import HitTestingService diff --git a/api/tests/unit_tests/controllers/console/explore/test_audio.py b/api/tests/unit_tests/controllers/console/explore/test_audio.py index 0afbc5a8f7..66c9ba48c5 100644 --- a/api/tests/unit_tests/controllers/console/explore/test_audio.py +++ b/api/tests/unit_tests/controllers/console/explore/test_audio.py @@ -2,6 +2,7 @@ from io import BytesIO from unittest.mock import MagicMock, patch import pytest +from graphon.model_runtime.errors.invoke import InvokeError from werkzeug.exceptions import InternalServerError import controllers.console.explore.audio as audio_module @@ -19,7 +20,6 @@ from core.errors.error import ( ProviderTokenNotInitError, QuotaExceededError, ) -from dify_graph.model_runtime.errors.invoke import InvokeError from services.errors.audio import ( AudioTooLargeServiceError, NoAudioUploadedServiceError, diff --git a/api/tests/unit_tests/controllers/console/explore/test_message.py b/api/tests/unit_tests/controllers/console/explore/test_message.py index 6b5c304884..2e4ca4f2a4 100644 --- a/api/tests/unit_tests/controllers/console/explore/test_message.py +++ b/api/tests/unit_tests/controllers/console/explore/test_message.py @@ -1,6 +1,7 @@ from unittest.mock import MagicMock, patch import pytest +from graphon.model_runtime.errors.invoke import InvokeError from werkzeug.exceptions import InternalServerError, NotFound import 
controllers.console.explore.message as module @@ -21,7 +22,6 @@ from core.errors.error import ( ProviderTokenNotInitError, QuotaExceededError, ) -from dify_graph.model_runtime.errors.invoke import InvokeError from services.errors.conversation import ConversationNotExistsError from services.errors.message import ( FirstMessageNotExistsError, diff --git a/api/tests/unit_tests/controllers/console/explore/test_trial.py b/api/tests/unit_tests/controllers/console/explore/test_trial.py index 5a03daecbc..04beb31389 100644 --- a/api/tests/unit_tests/controllers/console/explore/test_trial.py +++ b/api/tests/unit_tests/controllers/console/explore/test_trial.py @@ -3,6 +3,7 @@ from unittest.mock import MagicMock, patch from uuid import uuid4 import pytest +from graphon.model_runtime.errors.invoke import InvokeError from werkzeug.exceptions import Forbidden, InternalServerError, NotFound import controllers.console.explore.trial as module @@ -25,7 +26,6 @@ from core.errors.error import ( ProviderTokenNotInitError, QuotaExceededError, ) -from dify_graph.model_runtime.errors.invoke import InvokeError from models import Account from models.account import TenantStatus from models.model import AppMode diff --git a/api/tests/unit_tests/controllers/console/workspace/test_load_balancing_config.py b/api/tests/unit_tests/controllers/console/workspace/test_load_balancing_config.py index f2e57eb65f..9c42ee9529 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_load_balancing_config.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_load_balancing_config.py @@ -11,11 +11,10 @@ from unittest.mock import MagicMock import pytest from flask import Flask from flask.views import MethodView +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.errors.validate import CredentialsValidateFailedError from werkzeug.exceptions import Forbidden -from dify_graph.model_runtime.entities.model_entities import ModelType -from 
dify_graph.model_runtime.errors.validate import CredentialsValidateFailedError - if not hasattr(builtins, "MethodView"): builtins.MethodView = MethodView # type: ignore[attr-defined] diff --git a/api/tests/unit_tests/controllers/console/workspace/test_model_providers.py b/api/tests/unit_tests/controllers/console/workspace/test_model_providers.py index af0c2c5594..fb9eec98cb 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_model_providers.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_model_providers.py @@ -1,6 +1,7 @@ from unittest.mock import MagicMock, patch import pytest +from graphon.model_runtime.errors.validate import CredentialsValidateFailedError from pydantic_core import ValidationError from werkzeug.exceptions import Forbidden @@ -13,7 +14,6 @@ from controllers.console.workspace.model_providers import ( ModelProviderValidateApi, PreferredProviderTypeUpdateApi, ) -from dify_graph.model_runtime.errors.validate import CredentialsValidateFailedError VALID_UUID = "123e4567-e89b-12d3-a456-426614174000" INVALID_UUID = "123" diff --git a/api/tests/unit_tests/controllers/console/workspace/test_models.py b/api/tests/unit_tests/controllers/console/workspace/test_models.py index 43b8e1ac2e..c829327bc7 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_models.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_models.py @@ -2,6 +2,8 @@ from unittest.mock import MagicMock, patch import pytest from flask import Flask +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.errors.validate import CredentialsValidateFailedError from controllers.console.workspace.models import ( DefaultModelApi, @@ -14,8 +16,6 @@ from controllers.console.workspace.models import ( ModelProviderModelParameterRuleApi, ModelProviderModelValidateApi, ) -from dify_graph.model_runtime.entities.model_entities import ModelType -from dify_graph.model_runtime.errors.validate import 
CredentialsValidateFailedError def unwrap(func): diff --git a/api/tests/unit_tests/controllers/console/workspace/test_plugin.py b/api/tests/unit_tests/controllers/console/workspace/test_plugin.py index eb19243225..ce5fd1c466 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_plugin.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_plugin.py @@ -90,8 +90,8 @@ class TestPluginListLatestVersionsApi: side_effect=PluginDaemonClientSideError("error"), ), ): - with pytest.raises(ValueError): - method(api) + result = method(api) + assert result == ({"code": "plugin_error", "message": "error"}, 400) class TestPluginDebuggingKeyApi: @@ -120,8 +120,8 @@ class TestPluginDebuggingKeyApi: side_effect=PluginDaemonClientSideError("error"), ), ): - with pytest.raises(ValueError): - method(api) + result = method(api) + assert result == ({"code": "plugin_error", "message": "error"}, 400) class TestPluginListApi: @@ -202,8 +202,9 @@ class TestPluginUploadFromPkgApi: patch("controllers.console.workspace.plugin.dify_config.PLUGIN_MAX_PACKAGE_SIZE", 0), patch("controllers.console.workspace.plugin.PluginService.upload_pkg") as upload_pkg_mock, ): - with pytest.raises(ValueError): + with pytest.raises(ValueError) as exc_info: method(api) + assert "File size exceeds the maximum allowed size" in str(exc_info.value) upload_pkg_mock.assert_not_called() @@ -365,8 +366,8 @@ class TestPluginListInstallationsFromIdsApi: side_effect=PluginDaemonClientSideError("error"), ), ): - with pytest.raises(ValueError): - method(api) + result = method(api) + assert result == ({"code": "plugin_error", "message": "error"}, 400) class TestPluginUploadFromGithubApi: @@ -401,8 +402,8 @@ class TestPluginUploadFromGithubApi: side_effect=PluginDaemonClientSideError("error"), ), ): - with pytest.raises(ValueError): - method(api) + result = method(api) + assert result == ({"code": "plugin_error", "message": "error"}, 400) class TestPluginUploadFromBundleApi: @@ -449,8 +450,9 @@ class 
TestPluginUploadFromBundleApi: patch("controllers.console.workspace.plugin.dify_config.PLUGIN_MAX_BUNDLE_SIZE", 0), patch("controllers.console.workspace.plugin.PluginService.upload_bundle") as upload_bundle_mock, ): - with pytest.raises(ValueError): + with pytest.raises(ValueError) as exc_info: method(api) + assert "File size exceeds the maximum allowed size" in str(exc_info.value) upload_bundle_mock.assert_not_called() @@ -495,8 +497,8 @@ class TestPluginInstallFromGithubApi: side_effect=PluginDaemonClientSideError("error"), ), ): - with pytest.raises(ValueError): - method(api) + result = method(api) + assert result == ({"code": "plugin_error", "message": "error"}, 400) class TestPluginInstallFromMarketplaceApi: @@ -532,8 +534,8 @@ class TestPluginInstallFromMarketplaceApi: side_effect=PluginDaemonClientSideError("error"), ), ): - with pytest.raises(ValueError): - method(api) + result = method(api) + assert result == ({"code": "plugin_error", "message": "error"}, 400) class TestPluginFetchMarketplacePkgApi: @@ -562,8 +564,8 @@ class TestPluginFetchMarketplacePkgApi: side_effect=PluginDaemonClientSideError("error"), ), ): - with pytest.raises(ValueError): - method(api) + result = method(api) + assert result == ({"code": "plugin_error", "message": "error"}, 400) class TestPluginFetchManifestApi: @@ -595,8 +597,8 @@ class TestPluginFetchManifestApi: side_effect=PluginDaemonClientSideError("error"), ), ): - with pytest.raises(ValueError): - method(api) + result = method(api) + assert result == ({"code": "plugin_error", "message": "error"}, 400) class TestPluginFetchInstallTasksApi: @@ -625,8 +627,8 @@ class TestPluginFetchInstallTasksApi: side_effect=PluginDaemonClientSideError("error"), ), ): - with pytest.raises(ValueError): - method(api) + result = method(api) + assert result == ({"code": "plugin_error", "message": "error"}, 400) class TestPluginFetchInstallTaskApi: @@ -655,8 +657,8 @@ class TestPluginFetchInstallTaskApi: 
side_effect=PluginDaemonClientSideError("error"), ), ): - with pytest.raises(ValueError): - method(api, "t") + result = method(api, "t") + assert result == ({"code": "plugin_error", "message": "error"}, 400) class TestPluginDeleteInstallTaskApi: @@ -685,8 +687,8 @@ class TestPluginDeleteInstallTaskApi: side_effect=PluginDaemonClientSideError("error"), ), ): - with pytest.raises(ValueError): - method(api, "t") + result = method(api, "t") + assert result == ({"code": "plugin_error", "message": "error"}, 400) class TestPluginDeleteAllInstallTaskItemsApi: @@ -717,8 +719,8 @@ class TestPluginDeleteAllInstallTaskItemsApi: side_effect=PluginDaemonClientSideError("error"), ), ): - with pytest.raises(ValueError): - method(api) + result = method(api) + assert result == ({"code": "plugin_error", "message": "error"}, 400) class TestPluginDeleteInstallTaskItemApi: @@ -747,8 +749,8 @@ class TestPluginDeleteInstallTaskItemApi: side_effect=PluginDaemonClientSideError("error"), ), ): - with pytest.raises(ValueError): - method(api, "task1", "item1") + result = method(api, "task1", "item1") + assert result == ({"code": "plugin_error", "message": "error"}, 400) class TestPluginUpgradeFromMarketplaceApi: @@ -790,8 +792,8 @@ class TestPluginUpgradeFromMarketplaceApi: side_effect=PluginDaemonClientSideError("error"), ), ): - with pytest.raises(ValueError): - method(api) + result = method(api) + assert result == ({"code": "plugin_error", "message": "error"}, 400) class TestPluginUpgradeFromGithubApi: @@ -839,8 +841,8 @@ class TestPluginUpgradeFromGithubApi: side_effect=PluginDaemonClientSideError("error"), ), ): - with pytest.raises(ValueError): - method(api) + result = method(api) + assert result == ({"code": "plugin_error", "message": "error"}, 400) class TestPluginFetchDynamicSelectOptionsWithCredentialsApi: @@ -894,8 +896,8 @@ class TestPluginFetchDynamicSelectOptionsWithCredentialsApi: side_effect=PluginDaemonClientSideError("error"), ), ): - with pytest.raises(ValueError): - 
method(api) + result = method(api) + assert result == ({"code": "plugin_error", "message": "error"}, 400) class TestPluginChangePreferencesApi: diff --git a/api/tests/unit_tests/controllers/console/workspace/test_tool_providers.py b/api/tests/unit_tests/controllers/console/workspace/test_tool_providers.py index 94c3019d5e..44feacf2ad 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_tool_providers.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_tool_providers.py @@ -4,7 +4,7 @@ from __future__ import annotations import builtins import importlib -from contextlib import contextmanager +from contextlib import ExitStack, contextmanager from types import ModuleType, SimpleNamespace from unittest.mock import MagicMock, patch @@ -18,7 +18,6 @@ if not hasattr(builtins, "MethodView"): _CONTROLLER_MODULE: ModuleType | None = None _WRAPS_MODULE: ModuleType | None = None -_CONTROLLER_PATCHERS: list[patch] = [] @contextmanager @@ -37,6 +36,14 @@ def app() -> Flask: @pytest.fixture def controller_module(monkeypatch: pytest.MonkeyPatch): + """ + Import the controller with auth decorators neutralized only during import. + + The imported view classes retain those no-op decorators after import, so we + can restore the original globals immediately and avoid leaking auth patches + into unrelated tests such as libs.login unit coverage. 
+ """ + module_name = "controllers.console.workspace.tool_providers" global _CONTROLLER_MODULE if _CONTROLLER_MODULE is None: @@ -51,13 +58,12 @@ def controller_module(monkeypatch: pytest.MonkeyPatch): ("controllers.console.wraps.is_admin_or_owner_required", _noop), ("controllers.console.wraps.enterprise_license_required", _noop), ] - for target, value in patch_targets: - patcher = patch(target, value) - patcher.start() - _CONTROLLER_PATCHERS.append(patcher) monkeypatch.setenv("DIFY_SETUP_READY", "true") - with _mock_db(): - _CONTROLLER_MODULE = importlib.import_module(module_name) + with ExitStack() as stack: + for target, value in patch_targets: + stack.enter_context(patch(target, value)) + with _mock_db(): + _CONTROLLER_MODULE = importlib.import_module(module_name) module = _CONTROLLER_MODULE monkeypatch.setattr(module, "jsonable_encoder", lambda payload: payload) diff --git a/api/tests/unit_tests/controllers/console/workspace/test_workspace.py b/api/tests/unit_tests/controllers/console/workspace/test_workspace.py index f5ebe0b534..b2d13dbbdf 100644 --- a/api/tests/unit_tests/controllers/console/workspace/test_workspace.py +++ b/api/tests/unit_tests/controllers/console/workspace/test_workspace.py @@ -1,4 +1,3 @@ -from datetime import datetime from io import BytesIO from unittest.mock import MagicMock, patch @@ -26,6 +25,7 @@ from controllers.console.workspace.workspace import ( WorkspacePermissionApi, ) from enums.cloud_plan import CloudPlan +from libs.datetime_utils import naive_utc_now from models.account import TenantStatus @@ -44,13 +44,13 @@ class TestTenantListApi: id="t1", name="Tenant 1", status="active", - created_at=datetime.utcnow(), + created_at=naive_utc_now(), ) tenant2 = MagicMock( id="t2", name="Tenant 2", status="active", - created_at=datetime.utcnow(), + created_at=naive_utc_now(), ) with ( @@ -97,13 +97,13 @@ class TestTenantListApi: id="t1", name="Tenant 1", status="active", - created_at=datetime.utcnow(), + created_at=naive_utc_now(), ) 
tenant2 = MagicMock( id="t2", name="Tenant 2", status="active", - created_at=datetime.utcnow(), + created_at=naive_utc_now(), ) features_t2 = MagicMock() @@ -152,13 +152,13 @@ class TestTenantListApi: id="t1", name="Tenant 1", status="active", - created_at=datetime.utcnow(), + created_at=naive_utc_now(), ) tenant2 = MagicMock( id="t2", name="Tenant 2", status="active", - created_at=datetime.utcnow(), + created_at=naive_utc_now(), ) features = MagicMock() @@ -204,7 +204,7 @@ class TestTenantListApi: id="t1", name="Tenant", status="active", - created_at=datetime.utcnow(), + created_at=naive_utc_now(), ) features = MagicMock() @@ -243,13 +243,13 @@ class TestTenantListApi: id="t1", name="Tenant 1", status="active", - created_at=datetime.utcnow(), + created_at=naive_utc_now(), ) tenant2 = MagicMock( id="t2", name="Tenant 2", status="active", - created_at=datetime.utcnow(), + created_at=naive_utc_now(), ) with ( @@ -305,7 +305,7 @@ class TestWorkspaceListApi: api = WorkspaceListApi() method = unwrap(api.get) - tenant = MagicMock(id="t1", name="T", status="active", created_at=datetime.utcnow()) + tenant = MagicMock(id="t1", name="T", status="active", created_at=naive_utc_now()) paginate_result = MagicMock( items=[tenant], @@ -331,7 +331,7 @@ class TestWorkspaceListApi: id="t1", name="T", status="active", - created_at=datetime.utcnow(), + created_at=naive_utc_now(), ) paginate_result = MagicMock( diff --git a/api/tests/unit_tests/controllers/files/test_tool_files.py b/api/tests/unit_tests/controllers/files/test_tool_files.py index e5df7a1eea..edb91c3f26 100644 --- a/api/tests/unit_tests/controllers/files/test_tool_files.py +++ b/api/tests/unit_tests/controllers/files/test_tool_files.py @@ -18,10 +18,10 @@ def fake_request(args: dict): class DummyToolFile: - def __init__(self, mimetype="text/plain", size=10, name="tool.txt"): - self.mimetype = mimetype + def __init__(self, mime_type="text/plain", size=10, filename="tool.txt"): + self.mime_type = mime_type self.size = size 
- self.name = name + self.filename = filename @pytest.fixture(autouse=True) @@ -87,8 +87,8 @@ class TestToolFileApi: stream = iter([b"data"]) tool_file = DummyToolFile( - mimetype="application/pdf", - name="doc.pdf", + mime_type="application/pdf", + filename="doc.pdf", ) mock_tool_file_manager.return_value.get_file_generator_by_tool_file_id.return_value = ( diff --git a/api/tests/unit_tests/controllers/inner_api/app/__init__.py b/api/tests/unit_tests/controllers/inner_api/app/__init__.py new file mode 100644 index 0000000000..8b13789179 --- /dev/null +++ b/api/tests/unit_tests/controllers/inner_api/app/__init__.py @@ -0,0 +1 @@ + diff --git a/api/tests/unit_tests/controllers/inner_api/app/test_dsl.py b/api/tests/unit_tests/controllers/inner_api/app/test_dsl.py new file mode 100644 index 0000000000..4a5f91cc5d --- /dev/null +++ b/api/tests/unit_tests/controllers/inner_api/app/test_dsl.py @@ -0,0 +1,245 @@ +"""Unit tests for inner_api app DSL import/export endpoints. + +Tests Pydantic model validation, endpoint handler logic, and the +_get_active_account helper. Auth/setup decorators are tested separately +in test_auth_wraps.py; handler tests use inspect.unwrap() to bypass them. 
+""" + +import inspect +from unittest.mock import MagicMock, patch + +import pytest +from flask import Flask +from pydantic import ValidationError + +from controllers.inner_api.app.dsl import ( + EnterpriseAppDSLExport, + EnterpriseAppDSLImport, + InnerAppDSLImportPayload, + _get_active_account, +) +from services.app_dsl_service import ImportStatus + + +class TestInnerAppDSLImportPayload: + """Test InnerAppDSLImportPayload Pydantic model validation.""" + + def test_valid_payload_all_fields(self): + data = { + "yaml_content": "version: 0.6.0\nkind: app\n", + "creator_email": "user@example.com", + "name": "My App", + "description": "A test app", + } + payload = InnerAppDSLImportPayload.model_validate(data) + assert payload.yaml_content == data["yaml_content"] + assert payload.creator_email == "user@example.com" + assert payload.name == "My App" + assert payload.description == "A test app" + + def test_valid_payload_optional_fields_omitted(self): + data = { + "yaml_content": "version: 0.6.0\n", + "creator_email": "user@example.com", + } + payload = InnerAppDSLImportPayload.model_validate(data) + assert payload.name is None + assert payload.description is None + + def test_missing_yaml_content_fails(self): + with pytest.raises(ValidationError) as exc_info: + InnerAppDSLImportPayload.model_validate({"creator_email": "a@b.com"}) + assert "yaml_content" in str(exc_info.value) + + def test_missing_creator_email_fails(self): + with pytest.raises(ValidationError) as exc_info: + InnerAppDSLImportPayload.model_validate({"yaml_content": "test"}) + assert "creator_email" in str(exc_info.value) + + +class TestGetActiveAccount: + """Test the _get_active_account helper function.""" + + @patch("controllers.inner_api.app.dsl.db") + def test_returns_active_account(self, mock_db): + mock_account = MagicMock() + mock_account.status = "active" + mock_db.session.scalar.return_value = mock_account + + result = _get_active_account("user@example.com") + + assert result is mock_account + 
mock_db.session.scalar.assert_called_once() + + @patch("controllers.inner_api.app.dsl.db") + def test_returns_none_for_inactive_account(self, mock_db): + mock_account = MagicMock() + mock_account.status = "banned" + mock_db.session.scalar.return_value = mock_account + + result = _get_active_account("banned@example.com") + + assert result is None + + @patch("controllers.inner_api.app.dsl.db") + def test_returns_none_for_nonexistent_email(self, mock_db): + mock_db.session.scalar.return_value = None + + result = _get_active_account("missing@example.com") + + assert result is None + + +class TestEnterpriseAppDSLImport: + """Test EnterpriseAppDSLImport endpoint handler logic. + + Uses inspect.unwrap() to bypass auth/setup decorators. + """ + + @pytest.fixture + def api_instance(self): + return EnterpriseAppDSLImport() + + @pytest.fixture + def _mock_import_deps(self): + """Patch db, Session, and AppDslService for import handler tests.""" + with ( + patch("controllers.inner_api.app.dsl.db"), + patch("controllers.inner_api.app.dsl.Session") as mock_session, + patch("controllers.inner_api.app.dsl.AppDslService") as mock_dsl_cls, + ): + mock_session.return_value.__enter__ = MagicMock(return_value=MagicMock()) + mock_session.return_value.__exit__ = MagicMock(return_value=False) + self._mock_dsl = MagicMock() + mock_dsl_cls.return_value = self._mock_dsl + yield + + def _make_import_result(self, status: ImportStatus, **kwargs) -> "Import": + from services.app_dsl_service import Import + + result = Import( + id="import-id", + status=status, + app_id=kwargs.get("app_id", "app-123"), + app_mode=kwargs.get("app_mode", "workflow"), + ) + return result + + @pytest.mark.usefixtures("_mock_import_deps") + @patch("controllers.inner_api.app.dsl._get_active_account") + def test_import_success_returns_200(self, mock_get_account, api_instance, app: Flask): + mock_account = MagicMock() + mock_get_account.return_value = mock_account + self._mock_dsl.import_app.return_value = 
self._make_import_result(ImportStatus.COMPLETED) + + unwrapped = inspect.unwrap(api_instance.post) + with app.test_request_context(): + with patch("controllers.inner_api.app.dsl.inner_api_ns") as mock_ns: + mock_ns.payload = { + "yaml_content": "version: 0.6.0\n", + "creator_email": "user@example.com", + } + result = unwrapped(api_instance, workspace_id="ws-123") + + body, status_code = result + assert status_code == 200 + assert body["status"] == "completed" + mock_account.set_tenant_id.assert_called_once_with("ws-123") + + @pytest.mark.usefixtures("_mock_import_deps") + @patch("controllers.inner_api.app.dsl._get_active_account") + def test_import_pending_returns_202(self, mock_get_account, api_instance, app: Flask): + mock_get_account.return_value = MagicMock() + self._mock_dsl.import_app.return_value = self._make_import_result(ImportStatus.PENDING) + + unwrapped = inspect.unwrap(api_instance.post) + with app.test_request_context(): + with patch("controllers.inner_api.app.dsl.inner_api_ns") as mock_ns: + mock_ns.payload = {"yaml_content": "test", "creator_email": "u@e.com"} + body, status_code = unwrapped(api_instance, workspace_id="ws-123") + + assert status_code == 202 + assert body["status"] == "pending" + + @pytest.mark.usefixtures("_mock_import_deps") + @patch("controllers.inner_api.app.dsl._get_active_account") + def test_import_failed_returns_400(self, mock_get_account, api_instance, app: Flask): + mock_get_account.return_value = MagicMock() + self._mock_dsl.import_app.return_value = self._make_import_result(ImportStatus.FAILED) + + unwrapped = inspect.unwrap(api_instance.post) + with app.test_request_context(): + with patch("controllers.inner_api.app.dsl.inner_api_ns") as mock_ns: + mock_ns.payload = {"yaml_content": "test", "creator_email": "u@e.com"} + body, status_code = unwrapped(api_instance, workspace_id="ws-123") + + assert status_code == 400 + assert body["status"] == "failed" + + @patch("controllers.inner_api.app.dsl._get_active_account") + def 
test_import_account_not_found_returns_404(self, mock_get_account, api_instance, app: Flask): + mock_get_account.return_value = None + + unwrapped = inspect.unwrap(api_instance.post) + with app.test_request_context(): + with patch("controllers.inner_api.app.dsl.inner_api_ns") as mock_ns: + mock_ns.payload = {"yaml_content": "test", "creator_email": "missing@e.com"} + result = unwrapped(api_instance, workspace_id="ws-123") + + body, status_code = result + assert status_code == 404 + assert "missing@e.com" in body["message"] + + +class TestEnterpriseAppDSLExport: + """Test EnterpriseAppDSLExport endpoint handler logic. + + Uses inspect.unwrap() to bypass auth/setup decorators. + """ + + @pytest.fixture + def api_instance(self): + return EnterpriseAppDSLExport() + + @patch("controllers.inner_api.app.dsl.AppDslService") + @patch("controllers.inner_api.app.dsl.db") + def test_export_success_returns_200(self, mock_db, mock_dsl_cls, api_instance, app: Flask): + mock_app = MagicMock() + mock_db.session.get.return_value = mock_app + mock_dsl_cls.export_dsl.return_value = "version: 0.6.0\nkind: app\n" + + unwrapped = inspect.unwrap(api_instance.get) + with app.test_request_context("?include_secret=false"): + result = unwrapped(api_instance, app_id="app-123") + + body, status_code = result + assert status_code == 200 + assert body["data"] == "version: 0.6.0\nkind: app\n" + mock_dsl_cls.export_dsl.assert_called_once_with(app_model=mock_app, include_secret=False) + + @patch("controllers.inner_api.app.dsl.AppDslService") + @patch("controllers.inner_api.app.dsl.db") + def test_export_with_secret(self, mock_db, mock_dsl_cls, api_instance, app: Flask): + mock_app = MagicMock() + mock_db.session.get.return_value = mock_app + mock_dsl_cls.export_dsl.return_value = "yaml-data" + + unwrapped = inspect.unwrap(api_instance.get) + with app.test_request_context("?include_secret=true"): + result = unwrapped(api_instance, app_id="app-123") + + body, status_code = result + assert status_code 
== 200 + mock_dsl_cls.export_dsl.assert_called_once_with(app_model=mock_app, include_secret=True) + + @patch("controllers.inner_api.app.dsl.db") + def test_export_app_not_found_returns_404(self, mock_db, api_instance, app: Flask): + mock_db.session.get.return_value = None + + unwrapped = inspect.unwrap(api_instance.get) + with app.test_request_context("?include_secret=false"): + result = unwrapped(api_instance, app_id="nonexistent") + + body, status_code = result + assert status_code == 404 + assert "app not found" in body["message"] diff --git a/api/tests/unit_tests/controllers/service_api/app/test_app.py b/api/tests/unit_tests/controllers/service_api/app/test_app.py index f8e9cf9b80..1507bf7a5f 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_app.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_app.py @@ -65,7 +65,7 @@ class TestAppParameterApi: mock_tenant.status = "normal" # Mock DB queries for app and tenant - mock_db.session.query.return_value.where.return_value.first.side_effect = [ + mock_db.session.get.side_effect = [ mock_app_model, mock_tenant, ] @@ -112,7 +112,7 @@ class TestAppParameterApi: mock_tenant = Mock() mock_tenant.status = "normal" - mock_db.session.query.return_value.where.return_value.first.side_effect = [ + mock_db.session.get.side_effect = [ mock_app_model, mock_tenant, ] @@ -153,7 +153,7 @@ class TestAppParameterApi: mock_tenant = Mock() mock_tenant.status = "normal" - mock_db.session.query.return_value.where.return_value.first.side_effect = [ + mock_db.session.get.side_effect = [ mock_app_model, mock_tenant, ] @@ -192,7 +192,7 @@ class TestAppParameterApi: mock_tenant = Mock() mock_tenant.status = "normal" - mock_db.session.query.return_value.where.return_value.first.side_effect = [ + mock_db.session.get.side_effect = [ mock_app_model, mock_tenant, ] @@ -255,7 +255,7 @@ class TestAppMetaApi: mock_tenant = Mock() mock_tenant.status = "normal" - 
mock_db.session.query.return_value.where.return_value.first.side_effect = [ + mock_db.session.get.side_effect = [ mock_app_model, mock_tenant, ] @@ -323,7 +323,7 @@ class TestAppInfoApi: mock_tenant = Mock() mock_tenant.status = "normal" - mock_db.session.query.return_value.where.return_value.first.side_effect = [ + mock_db.session.get.side_effect = [ mock_app_model, mock_tenant, ] @@ -380,7 +380,7 @@ class TestAppInfoApi: mock_tenant = Mock() mock_tenant.status = "normal" - mock_db.session.query.return_value.where.return_value.first.side_effect = [ + mock_db.session.get.side_effect = [ mock_app, mock_tenant, ] @@ -426,7 +426,7 @@ class TestAppInfoApi: mock_tenant = Mock() mock_tenant.status = "normal" - mock_db.session.query.return_value.where.return_value.first.side_effect = [ + mock_db.session.get.side_effect = [ mock_app, mock_tenant, ] @@ -478,7 +478,7 @@ class TestAppInfoApi: mock_tenant = Mock() mock_tenant.status = "normal" - mock_db.session.query.return_value.where.return_value.first.side_effect = [ + mock_db.session.get.side_effect = [ mock_app, mock_tenant, ] diff --git a/api/tests/unit_tests/controllers/service_api/app/test_audio.py b/api/tests/unit_tests/controllers/service_api/app/test_audio.py index 1923ab7fa7..5a8cb4619f 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_audio.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_audio.py @@ -13,6 +13,7 @@ from types import SimpleNamespace from unittest.mock import Mock, patch import pytest +from graphon.model_runtime.errors.invoke import InvokeError from werkzeug.datastructures import FileStorage from werkzeug.exceptions import InternalServerError @@ -29,7 +30,6 @@ from controllers.service_api.app.error import ( UnsupportedAudioTypeError, ) from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from dify_graph.model_runtime.errors.invoke import InvokeError from services.audio_service import AudioService from 
services.errors.app_model_config import AppModelConfigBrokenError from services.errors.audio import ( diff --git a/api/tests/unit_tests/controllers/service_api/app/test_completion.py b/api/tests/unit_tests/controllers/service_api/app/test_completion.py index 4e4482f704..57681d8f5b 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_completion.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_completion.py @@ -16,6 +16,7 @@ from types import SimpleNamespace from unittest.mock import Mock, patch import pytest +from graphon.model_runtime.errors.invoke import InvokeError from pydantic import ValidationError from werkzeug.exceptions import BadRequest, NotFound @@ -34,7 +35,6 @@ from controllers.service_api.app.error import ( NotChatAppError, ) from core.errors.error import QuotaExceededError -from dify_graph.model_runtime.errors.invoke import InvokeError from models.model import App, AppMode, EndUser from services.app_generate_service import AppGenerateService from services.app_task_service import AppTaskService diff --git a/api/tests/unit_tests/controllers/service_api/app/test_file_preview.py b/api/tests/unit_tests/controllers/service_api/app/test_file_preview.py index 1bdcd0f1a3..d83c22f2cf 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_file_preview.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_file_preview.py @@ -79,10 +79,13 @@ class TestFilePreviewApi: mock_message_file.message_id = mock_message.id with patch("controllers.service_api.app.file_preview.db") as mock_db: - # Mock database queries - mock_db.session.query.return_value.where.return_value.first.side_effect = [ + # Mock scalar() for MessageFile and Message queries + mock_db.session.scalar.side_effect = [ mock_message_file, # MessageFile query mock_message, # Message query + ] + # Mock get() for UploadFile and App PK lookups + mock_db.session.get.side_effect = [ mock_upload_file, # UploadFile query mock_app, # App query for tenant validation ] 
@@ -100,8 +103,8 @@ class TestFilePreviewApi: app_id = str(uuid.uuid4()) with patch("controllers.service_api.app.file_preview.db") as mock_db: - # Mock MessageFile not found - mock_db.session.query.return_value.where.return_value.first.return_value = None + # Mock MessageFile not found via scalar() + mock_db.session.scalar.return_value = None # Execute and assert exception with pytest.raises(FileNotFoundError) as exc_info: @@ -115,8 +118,8 @@ class TestFilePreviewApi: app_id = str(uuid.uuid4()) with patch("controllers.service_api.app.file_preview.db") as mock_db: - # Mock MessageFile found but Message not owned by app - mock_db.session.query.return_value.where.return_value.first.side_effect = [ + # Mock MessageFile found but Message not owned by app via scalar() + mock_db.session.scalar.side_effect = [ mock_message_file, # MessageFile query - found None, # Message query - not found (access denied) ] @@ -133,12 +136,13 @@ class TestFilePreviewApi: app_id = str(uuid.uuid4()) with patch("controllers.service_api.app.file_preview.db") as mock_db: - # Mock MessageFile and Message found but UploadFile not found - mock_db.session.query.return_value.where.return_value.first.side_effect = [ + # Mock scalar() for MessageFile and Message + mock_db.session.scalar.side_effect = [ mock_message_file, # MessageFile query - found mock_message, # Message query - found - None, # UploadFile query - not found ] + # Mock get() for UploadFile - not found + mock_db.session.get.return_value = None # Execute and assert exception with pytest.raises(FileNotFoundError) as exc_info: @@ -161,10 +165,13 @@ class TestFilePreviewApi: mock_message_file.message_id = mock_message.id with patch("controllers.service_api.app.file_preview.db") as mock_db: - # Mock database queries - mock_db.session.query.return_value.where.return_value.first.side_effect = [ + # Mock scalar() for MessageFile and Message queries + mock_db.session.scalar.side_effect = [ mock_message_file, # MessageFile query mock_message, # 
Message query + ] + # Mock get() for UploadFile and App PK lookups + mock_db.session.get.side_effect = [ mock_upload_file, # UploadFile query mock_app, # App query for tenant validation ] @@ -262,10 +269,13 @@ class TestFilePreviewApi: mock_storage.load.return_value = mock_generator with patch("controllers.service_api.app.file_preview.db") as mock_db: - # Mock database queries - mock_db.session.query.return_value.where.return_value.first.side_effect = [ + # Mock scalar() for MessageFile and Message queries + mock_db.session.scalar.side_effect = [ mock_message_file, # MessageFile query mock_message, # Message query + ] + # Mock get() for UploadFile and App PK lookups + mock_db.session.get.side_effect = [ mock_upload_file, # UploadFile query mock_app, # App query for tenant validation ] @@ -301,10 +311,13 @@ class TestFilePreviewApi: mock_storage.load.side_effect = Exception("Storage error") with patch("controllers.service_api.app.file_preview.db") as mock_db: - # Mock database queries for validation - mock_db.session.query.return_value.where.return_value.first.side_effect = [ + # Mock scalar() for MessageFile and Message queries + mock_db.session.scalar.side_effect = [ mock_message_file, # MessageFile query mock_message, # Message query + ] + # Mock get() for UploadFile and App PK lookups + mock_db.session.get.side_effect = [ mock_upload_file, # UploadFile query mock_app, # App query for tenant validation ] @@ -327,8 +340,8 @@ class TestFilePreviewApi: app_id = str(uuid.uuid4()) with patch("controllers.service_api.app.file_preview.db") as mock_db: - # Mock database query to raise unexpected exception - mock_db.session.query.side_effect = Exception("Unexpected database error") + # Mock database scalar to raise unexpected exception + mock_db.session.scalar.side_effect = Exception("Unexpected database error") # Execute and assert exception with pytest.raises(FileAccessDeniedError) as exc_info: diff --git 
a/api/tests/unit_tests/controllers/service_api/app/test_workflow.py b/api/tests/unit_tests/controllers/service_api/app/test_workflow.py index 4eada73b82..b1f036c6f3 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_workflow.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_workflow.py @@ -19,6 +19,7 @@ from types import SimpleNamespace from unittest.mock import Mock, patch import pytest +from graphon.enums import WorkflowExecutionStatus from werkzeug.exceptions import BadRequest, NotFound from controllers.service_api.app.error import NotWorkflowAppError @@ -35,7 +36,6 @@ from controllers.service_api.app.workflow import ( WorkflowTaskStopApi, ) from controllers.web.error import InvokeRateLimitError as InvokeRateLimitHttpError -from dify_graph.enums import WorkflowExecutionStatus from models.model import App, AppMode from services.app_generate_service import AppGenerateService from services.errors.app import IsDraftWorkflowError, WorkflowNotFoundError @@ -315,7 +315,7 @@ class TestWorkflowStopMechanism: def test_graph_engine_manager_has_send_stop_command(self): """Test GraphEngineManager has send_stop_command method.""" - from dify_graph.graph_engine.manager import GraphEngineManager + from graphon.graph_engine.manager import GraphEngineManager assert hasattr(GraphEngineManager, "send_stop_command") diff --git a/api/tests/unit_tests/controllers/service_api/app/test_workflow_fields.py b/api/tests/unit_tests/controllers/service_api/app/test_workflow_fields.py index 9e95f45a0a..4b8e3a738c 100644 --- a/api/tests/unit_tests/controllers/service_api/app/test_workflow_fields.py +++ b/api/tests/unit_tests/controllers/service_api/app/test_workflow_fields.py @@ -1,7 +1,8 @@ from types import SimpleNamespace +from graphon.enums import WorkflowExecutionStatus + from controllers.service_api.app.workflow import WorkflowRunOutputsField, WorkflowRunStatusField -from dify_graph.enums import WorkflowExecutionStatus def 
test_workflow_run_status_field_with_enum() -> None: diff --git a/api/tests/unit_tests/controllers/service_api/conftest.py b/api/tests/unit_tests/controllers/service_api/conftest.py index 01d2d1e7c0..eddba5a517 100644 --- a/api/tests/unit_tests/controllers/service_api/conftest.py +++ b/api/tests/unit_tests/controllers/service_api/conftest.py @@ -119,11 +119,8 @@ class AuthenticationMocker: @staticmethod def setup_db_queries(mock_db, mock_app, mock_tenant, mock_account=None): - """Configure mock_db to return app and tenant in sequence.""" - mock_db.session.query.return_value.where.return_value.first.side_effect = [ - mock_app, - mock_tenant, - ] + """Configure mock_db to return app and tenant via session.get().""" + mock_db.session.get.side_effect = [mock_app, mock_tenant] if mock_account: mock_ta = Mock() @@ -136,11 +133,9 @@ class AuthenticationMocker: mock_ta = Mock() mock_ta.account_id = mock_account.id - mock_query = mock_db.session.query.return_value - target_mock = mock_query.where.return_value.where.return_value.where.return_value.where.return_value - target_mock.one_or_none.return_value = (mock_tenant, mock_ta) + mock_db.session.execute.return_value.one_or_none.return_value = (mock_tenant, mock_ta) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_account + mock_db.session.get.return_value = mock_account @pytest.fixture diff --git a/api/tests/unit_tests/controllers/service_api/dataset/test_dataset.py b/api/tests/unit_tests/controllers/service_api/dataset/test_dataset.py index 8fe41cd19f..910d781cd0 100644 --- a/api/tests/unit_tests/controllers/service_api/dataset/test_dataset.py +++ b/api/tests/unit_tests/controllers/service_api/dataset/test_dataset.py @@ -942,11 +942,11 @@ class TestDatasetListApiGet: """Test suite for DatasetListApi.get() endpoint. ``get`` has no billing decorators but calls ``current_user``, - ``DatasetService``, ``ProviderManager``, and ``marshal``. 
+ ``DatasetService``, ``create_plugin_provider_manager``, and ``marshal``. """ @patch("controllers.service_api.dataset.dataset.marshal") - @patch("controllers.service_api.dataset.dataset.ProviderManager") + @patch("controllers.service_api.dataset.dataset.create_plugin_provider_manager") @patch("controllers.service_api.dataset.dataset.current_user") @patch("controllers.service_api.dataset.dataset.DatasetService") def test_list_datasets_success( @@ -1044,12 +1044,12 @@ class TestDatasetApiGet: """Test suite for DatasetApi.get() endpoint. ``get`` has no billing decorators but calls ``DatasetService``, - ``ProviderManager``, ``marshal``, and ``current_user``. + ``create_plugin_provider_manager``, ``marshal``, and ``current_user``. """ @patch("controllers.service_api.dataset.dataset.DatasetPermissionService") @patch("controllers.service_api.dataset.dataset.marshal") - @patch("controllers.service_api.dataset.dataset.ProviderManager") + @patch("controllers.service_api.dataset.dataset.create_plugin_provider_manager") @patch("controllers.service_api.dataset.dataset.current_user") @patch("controllers.service_api.dataset.dataset.DatasetService") def test_get_dataset_success( diff --git a/api/tests/unit_tests/controllers/service_api/dataset/test_dataset_segment.py b/api/tests/unit_tests/controllers/service_api/dataset/test_dataset_segment.py index 73a87761d5..7f5d6b0839 100644 --- a/api/tests/unit_tests/controllers/service_api/dataset/test_dataset_segment.py +++ b/api/tests/unit_tests/controllers/service_api/dataset/test_dataset_segment.py @@ -788,7 +788,7 @@ class TestSegmentApiGet: """Test successful segment list retrieval.""" # Arrange mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc_svc.get_document.return_value = Mock(doc_form=IndexStructureType.PARAGRAPH_INDEX) mock_seg_svc.get_segments.return_value = 
([mock_segment], 1) mock_marshal.return_value = [{"id": mock_segment.id}] @@ -813,7 +813,7 @@ class TestSegmentApiGet: """Test 404 when dataset not found.""" # Arrange mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None # Act & Assert with app.test_request_context( @@ -833,7 +833,7 @@ class TestSegmentApiGet: """Test 404 when document not found.""" # Arrange mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc_svc.get_document.return_value = None # Act & Assert @@ -899,7 +899,7 @@ class TestSegmentApiPost: mock_account_fn.return_value = (Mock(), mock_tenant.id) mock_dataset.indexing_technique = "economy" - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc = Mock() mock_doc.indexing_status = "completed" @@ -950,7 +950,7 @@ class TestSegmentApiPost: mock_account_fn.return_value = (Mock(), mock_tenant.id) mock_dataset.indexing_technique = "economy" - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc = Mock() mock_doc.indexing_status = "completed" @@ -992,7 +992,7 @@ class TestSegmentApiPost: self._setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id) mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc = Mock() mock_doc.indexing_status = "indexing" # Not completed @@ -1043,7 +1043,7 @@ class TestDatasetSegmentApiDelete: """Test successful segment deletion.""" # Arrange mock_account_fn.return_value = (Mock(), 
mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_dataset_svc.check_dataset_model_setting.return_value = None mock_doc = Mock() @@ -1087,7 +1087,7 @@ class TestDatasetSegmentApiDelete: """Test 404 when segment not found.""" # Arrange mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc = Mock() mock_doc.indexing_status = "completed" @@ -1129,7 +1129,7 @@ class TestDatasetSegmentApiDelete: """Test 404 when dataset not found for delete.""" # Arrange mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None # Act & Assert with app.test_request_context( @@ -1163,7 +1163,7 @@ class TestDatasetSegmentApiDelete: """Test 404 when document not found for delete.""" # Arrange mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_dataset_svc.check_dataset_model_setting.return_value = None mock_doc_svc.get_document.return_value = None @@ -1233,7 +1233,7 @@ class TestDatasetSegmentApiUpdate: self._setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id) mock_account_fn.return_value = (Mock(), mock_tenant.id) mock_dataset.indexing_technique = "economy" - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_dataset_svc.check_dataset_model_setting.return_value = None mock_doc_svc.get_document.return_value = Mock() mock_seg_svc.get_segment_by_id.return_value = mock_segment @@ -1280,7 +1280,7 @@ class TestDatasetSegmentApiUpdate: """Test 404 
when dataset not found for update.""" self._setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id) mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None with app.test_request_context( f"/datasets/{mock_dataset.id}/documents/doc-id/segments/seg-id", @@ -1321,7 +1321,7 @@ class TestDatasetSegmentApiUpdate: self._setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id) mock_account_fn.return_value = (Mock(), mock_tenant.id) mock_dataset.indexing_technique = "economy" - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_dataset_svc.check_dataset_model_setting.return_value = None mock_doc_svc.get_document.return_value = Mock() mock_seg_svc.get_segment_by_id.return_value = None @@ -1370,7 +1370,7 @@ class TestDatasetSegmentApiGetSingle: ): """Test successful single segment retrieval.""" mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_dataset_svc.check_dataset_model_setting.return_value = None mock_doc = Mock(doc_form=IndexStructureType.PARAGRAPH_INDEX) mock_doc_svc.get_document.return_value = mock_doc @@ -1405,7 +1405,7 @@ class TestDatasetSegmentApiGetSingle: ): """Test 404 when dataset not found.""" mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None with app.test_request_context( f"/datasets/{mock_dataset.id}/documents/doc-id/segments/seg-id", @@ -1436,7 +1436,7 @@ class TestDatasetSegmentApiGetSingle: ): """Test 404 when document not found.""" mock_account_fn.return_value = (Mock(), mock_tenant.id) - 
mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_dataset_svc.check_dataset_model_setting.return_value = None mock_doc_svc.get_document.return_value = None @@ -1471,7 +1471,7 @@ class TestDatasetSegmentApiGetSingle: ): """Test 404 when segment not found.""" mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_dataset_svc.check_dataset_model_setting.return_value = None mock_doc_svc.get_document.return_value = Mock() mock_seg_svc.get_segment_by_id.return_value = None @@ -1515,7 +1515,7 @@ class TestChildChunkApiGet: ): """Test successful child chunk list retrieval.""" mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc_svc.get_document.return_value = Mock() mock_seg_svc.get_segment_by_id.return_value = Mock() @@ -1554,7 +1554,7 @@ class TestChildChunkApiGet: ): """Test 404 when dataset not found.""" mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None with app.test_request_context( f"/datasets/{mock_dataset.id}/documents/doc-id/segments/seg-id/child_chunks", @@ -1583,7 +1583,7 @@ class TestChildChunkApiGet: ): """Test 404 when document not found.""" mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc_svc.get_document.return_value = None with app.test_request_context( @@ -1615,7 +1615,7 @@ class TestChildChunkApiGet: ): """Test 404 when segment not found.""" mock_account_fn.return_value = 
(Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc_svc.get_document.return_value = Mock() mock_seg_svc.get_segment_by_id.return_value = None @@ -1676,7 +1676,7 @@ class TestChildChunkApiPost: self._setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id) mock_account_fn.return_value = (Mock(), mock_tenant.id) mock_dataset.indexing_technique = "economy" - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc_svc.get_document.return_value = Mock() mock_seg_svc.get_segment_by_id.return_value = Mock() mock_child = Mock() @@ -1717,7 +1717,7 @@ class TestChildChunkApiPost: """Test 404 when dataset not found.""" self._setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id) mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None with app.test_request_context( f"/datasets/{mock_dataset.id}/documents/doc-id/segments/seg-id/child_chunks", @@ -1755,7 +1755,7 @@ class TestChildChunkApiPost: """Test 404 when segment not found.""" self._setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id) mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc_svc.get_document.return_value = Mock() mock_seg_svc.get_segment_by_id.return_value = None @@ -1808,7 +1808,7 @@ class TestDatasetChildChunkApiDelete: ): """Test successful child chunk deletion.""" mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = 
mock_dataset mock_doc = Mock() mock_doc_svc.get_document.return_value = mock_doc @@ -1858,7 +1858,7 @@ class TestDatasetChildChunkApiDelete: ): """Test 404 when child chunk not found.""" mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc_svc.get_document.return_value = Mock() segment_id = str(uuid.uuid4()) @@ -1899,7 +1899,7 @@ class TestDatasetChildChunkApiDelete: ): """Test 404 when segment does not belong to the document.""" mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc_svc.get_document.return_value = Mock() segment_id = str(uuid.uuid4()) @@ -1939,7 +1939,7 @@ class TestDatasetChildChunkApiDelete: ): """Test 404 when child chunk does not belong to the segment.""" mock_account_fn.return_value = (Mock(), mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc_svc.get_document.return_value = Mock() segment_id = str(uuid.uuid4()) diff --git a/api/tests/unit_tests/controllers/service_api/dataset/test_document.py b/api/tests/unit_tests/controllers/service_api/dataset/test_document.py index 7f77e61ee4..12d5e7345d 100644 --- a/api/tests/unit_tests/controllers/service_api/dataset/test_document.py +++ b/api/tests/unit_tests/controllers/service_api/dataset/test_document.py @@ -717,7 +717,7 @@ class TestDocumentApiDelete: dataset_id = str(uuid.uuid4()) mock_dataset = Mock() mock_dataset.id = dataset_id - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc_svc.get_document.return_value = mock_document mock_doc_svc.check_archived.return_value = False @@ 
-746,7 +746,7 @@ class TestDocumentApiDelete: document_id = str(uuid.uuid4()) mock_dataset = Mock() mock_dataset.id = dataset_id - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc_svc.get_document.return_value = None @@ -767,7 +767,7 @@ class TestDocumentApiDelete: dataset_id = str(uuid.uuid4()) mock_dataset = Mock() mock_dataset.id = dataset_id - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc_svc.get_document.return_value = mock_document mock_doc_svc.check_archived.return_value = True @@ -788,7 +788,7 @@ class TestDocumentApiDelete: # Arrange dataset_id = str(uuid.uuid4()) document_id = str(uuid.uuid4()) - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None # Act & Assert with app.test_request_context( @@ -809,7 +809,7 @@ class TestDocumentListApi: def test_list_documents_success(self, mock_db, mock_doc_svc, mock_marshal, app, mock_tenant, mock_dataset): """Test successful document list retrieval.""" # Arrange - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_pagination = Mock() mock_pagination.items = [Mock(), Mock()] @@ -838,7 +838,7 @@ class TestDocumentListApi: def test_list_documents_dataset_not_found(self, mock_db, app, mock_tenant, mock_dataset): """Test 404 when dataset not found.""" # Arrange - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None # Act & Assert with app.test_request_context( @@ -860,8 +860,6 @@ class TestDocumentIndexingStatusApi: """Test successful indexing status retrieval.""" # Arrange batch_id = "batch_123" - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset 
- mock_doc = Mock() mock_doc.id = str(uuid.uuid4()) mock_doc.is_paused = False @@ -877,8 +875,8 @@ class TestDocumentIndexingStatusApi: mock_doc_svc.get_batch_documents.return_value = [mock_doc] - # Mock segment count queries - mock_db.session.query.return_value.where.return_value.where.return_value.count.return_value = 5 + # scalar() called 3 times: dataset lookup, completed_segments count, total_segments count + mock_db.session.scalar.side_effect = [mock_dataset, 5, 5] mock_marshal.return_value = {"id": mock_doc.id, "indexing_status": "completed"} # Act @@ -898,7 +896,7 @@ class TestDocumentIndexingStatusApi: """Test 404 when dataset not found.""" # Arrange batch_id = "batch_123" - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None # Act & Assert with app.test_request_context( @@ -915,7 +913,7 @@ class TestDocumentIndexingStatusApi: """Test 404 when no documents found for batch.""" # Arrange batch_id = "batch_empty" - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_doc_svc.get_batch_documents.return_value = [] # Act & Assert @@ -986,7 +984,7 @@ class TestDocumentAddByTextApi: # Arrange — neutralise billing decorators self._setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_dataset.indexing_technique = "economy" mock_current_user.id = str(uuid.uuid4()) @@ -1035,7 +1033,7 @@ class TestDocumentAddByTextApi: # Arrange — neutralise billing decorators self._setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None # Act & Assert with app.test_request_context( @@ -1064,7 +1062,7 @@ class 
TestDocumentAddByTextApi: self._setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id) mock_dataset.indexing_technique = None - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset # Act & Assert with app.test_request_context( @@ -1150,7 +1148,7 @@ class TestDocumentUpdateByTextApiPost: _setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id) mock_dataset.indexing_technique = "economy" mock_dataset.latest_process_rule = Mock() - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_current_user.id = "user-1" mock_upload = Mock() @@ -1193,7 +1191,7 @@ class TestDocumentUpdateByTextApiPost: ): """Test ValueError when dataset not found.""" _setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None doc_id = str(uuid.uuid4()) with app.test_request_context( @@ -1232,7 +1230,7 @@ class TestDocumentAddByFileApiPost: ): """Test ValueError when dataset not found.""" _setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None from io import BytesIO @@ -1263,7 +1261,7 @@ class TestDocumentAddByFileApiPost: """Test ValueError when dataset is external.""" _setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id) mock_dataset.provider = "external" - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset from io import BytesIO @@ -1298,7 +1296,7 @@ class TestDocumentAddByFileApiPost: mock_dataset.provider = "vendor" mock_dataset.indexing_technique = "economy" 
mock_dataset.chunk_structure = None - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset with app.test_request_context( f"/datasets/{mock_dataset.id}/document/create_by_file", @@ -1328,7 +1326,7 @@ class TestDocumentAddByFileApiPost: mock_dataset.provider = "vendor" mock_dataset.indexing_technique = None mock_dataset.chunk_structure = None - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset from io import BytesIO @@ -1366,7 +1364,7 @@ class TestDocumentUpdateByFileApiPost: ): """Test ValueError when dataset not found.""" _setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id) - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None from io import BytesIO @@ -1402,7 +1400,7 @@ class TestDocumentUpdateByFileApiPost: """Test ValueError when dataset is external.""" _setup_billing_mocks(mock_validate_token, mock_feature_svc, mock_tenant.id) mock_dataset.provider = "external" - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset from io import BytesIO @@ -1450,7 +1448,7 @@ class TestDocumentUpdateByFileApiPost: mock_dataset.chunk_structure = None mock_dataset.latest_process_rule = Mock() mock_dataset.created_by_account = Mock() - mock_db.session.query.return_value.where.return_value.first.return_value = mock_dataset + mock_db.session.scalar.return_value = mock_dataset mock_current_user.id = "user-1" mock_upload = Mock() diff --git a/api/tests/unit_tests/controllers/service_api/test_site.py b/api/tests/unit_tests/controllers/service_api/test_site.py index b58caf3be1..c0b40d070a 100644 --- a/api/tests/unit_tests/controllers/service_api/test_site.py +++ b/api/tests/unit_tests/controllers/service_api/test_site.py @@ -88,7 
+88,7 @@ class TestAppSiteApi: mock_app_model.tenant = mock_tenant # Mock wraps.db for authentication - mock_wraps_db.session.query.return_value.where.return_value.first.side_effect = [ + mock_wraps_db.session.get.side_effect = [ mock_app_model, mock_tenant, ] @@ -98,7 +98,7 @@ class TestAppSiteApi: setup_mock_tenant_account_query(mock_wraps_db, mock_tenant, mock_account) # Mock site.db for site query - mock_db.session.query.return_value.where.return_value.first.return_value = mock_site + mock_db.session.scalar.return_value = mock_site # Act with app.test_request_context("/site", method="GET", headers={"Authorization": "Bearer test_token"}): @@ -109,7 +109,7 @@ class TestAppSiteApi: assert response["title"] == "Test Site" assert response["icon"] == "icon-url" assert response["description"] == "Site description" - mock_db.session.query.assert_called_once_with(Site) + mock_db.session.scalar.assert_called_once() @patch("controllers.service_api.wraps.user_logged_in") @patch("controllers.service_api.app.site.db") @@ -140,7 +140,7 @@ class TestAppSiteApi: mock_tenant.status = TenantStatus.NORMAL mock_app_model.tenant = mock_tenant - mock_wraps_db.session.query.return_value.where.return_value.first.side_effect = [ + mock_wraps_db.session.get.side_effect = [ mock_app_model, mock_tenant, ] @@ -150,7 +150,7 @@ class TestAppSiteApi: setup_mock_tenant_account_query(mock_wraps_db, mock_tenant, mock_account) # Mock site query to return None - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None # Act & Assert with app.test_request_context("/site", method="GET", headers={"Authorization": "Bearer test_token"}): @@ -187,7 +187,7 @@ class TestAppSiteApi: mock_tenant = Mock() mock_tenant.status = TenantStatus.NORMAL - mock_wraps_db.session.query.return_value.where.return_value.first.side_effect = [ + mock_wraps_db.session.get.side_effect = [ mock_app_model, mock_tenant, ] @@ -197,7 +197,7 @@ class TestAppSiteApi: 
setup_mock_tenant_account_query(mock_wraps_db, mock_tenant, mock_account) # Mock site query - mock_db.session.query.return_value.where.return_value.first.return_value = mock_site + mock_db.session.scalar.return_value = mock_site # Set tenant status to archived AFTER authentication mock_app_model.tenant.status = TenantStatus.ARCHIVE @@ -230,7 +230,7 @@ class TestAppSiteApi: mock_tenant.status = TenantStatus.NORMAL mock_app_model.tenant = mock_tenant - mock_wraps_db.session.query.return_value.where.return_value.first.side_effect = [ + mock_wraps_db.session.get.side_effect = [ mock_app_model, mock_tenant, ] @@ -258,7 +258,7 @@ class TestAppSiteApi: mock_site.icon_type = "image" mock_site.created_at = "2024-01-01T00:00:00" mock_site.updated_at = "2024-01-01T00:00:00" - mock_db.session.query.return_value.where.return_value.first.return_value = mock_site + mock_db.session.scalar.return_value = mock_site # Act with app.test_request_context("/site", method="GET", headers={"Authorization": "Bearer test_token"}): @@ -267,4 +267,4 @@ class TestAppSiteApi: # Assert # The query was executed successfully (site returned), which validates the correct query was made - mock_db.session.query.assert_called_once_with(Site) + mock_db.session.scalar.assert_called_once() diff --git a/api/tests/unit_tests/controllers/service_api/test_wraps.py b/api/tests/unit_tests/controllers/service_api/test_wraps.py index 9c2d075f41..a2008e024b 100644 --- a/api/tests/unit_tests/controllers/service_api/test_wraps.py +++ b/api/tests/unit_tests/controllers/service_api/test_wraps.py @@ -144,14 +144,10 @@ class TestValidateAppToken: mock_ta = Mock() mock_ta.account_id = mock_account.id - # Use side_effect to return app first, then tenant - mock_db.session.query.return_value.where.return_value.first.side_effect = [ - mock_app, - mock_tenant, - mock_account, - ] + # Use side_effect to return app first, then tenant via session.get() + mock_db.session.get.side_effect = [mock_app, mock_tenant] - # Mock the tenant 
owner query + # Mock the tenant owner query (execute(select(...)).one_or_none()) setup_mock_tenant_account_query(mock_db, mock_tenant, mock_ta) @validate_app_token @@ -175,7 +171,7 @@ class TestValidateAppToken: mock_api_token.app_id = str(uuid.uuid4()) mock_validate_token.return_value = mock_api_token - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.get.return_value = None @validate_app_token def protected_view(**kwargs): @@ -198,7 +194,7 @@ class TestValidateAppToken: mock_app = Mock() mock_app.status = "abnormal" - mock_db.session.query.return_value.where.return_value.first.return_value = mock_app + mock_db.session.get.return_value = mock_app @validate_app_token def protected_view(**kwargs): @@ -222,7 +218,7 @@ class TestValidateAppToken: mock_app = Mock() mock_app.status = "normal" mock_app.enable_api = False - mock_db.session.query.return_value.where.return_value.first.return_value = mock_app + mock_db.session.get.return_value = mock_app @validate_app_token def protected_view(**kwargs): @@ -474,11 +470,11 @@ class TestValidateDatasetToken: mock_account.id = mock_ta.account_id mock_account.current_tenant = mock_tenant - # Mock the tenant account join query + # Mock the tenant account join query (execute(select(...)).one_or_none()) setup_mock_dataset_tenant_query(mock_db, mock_tenant, mock_ta) - # Mock the account query - mock_db.session.query.return_value.where.return_value.first.return_value = mock_account + # Mock the account lookup via session.get() + mock_db.session.get.return_value = mock_account @validate_dataset_token def protected_view(tenant_id): @@ -501,7 +497,7 @@ class TestValidateDatasetToken: mock_api_token.tenant_id = str(uuid.uuid4()) mock_validate_token.return_value = mock_api_token - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None @validate_dataset_token def protected_view(dataset_id=None, **kwargs): diff --git 
a/api/tests/unit_tests/controllers/web/test_audio.py b/api/tests/unit_tests/controllers/web/test_audio.py index 01f34345aa..cbfc8fa613 100644 --- a/api/tests/unit_tests/controllers/web/test_audio.py +++ b/api/tests/unit_tests/controllers/web/test_audio.py @@ -8,6 +8,7 @@ from unittest.mock import MagicMock, patch import pytest from flask import Flask +from graphon.model_runtime.errors.invoke import InvokeError from controllers.web.audio import AudioApi, TextApi from controllers.web.error import ( @@ -21,7 +22,6 @@ from controllers.web.error import ( UnsupportedAudioTypeError, ) from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from dify_graph.model_runtime.errors.invoke import InvokeError from services.errors.audio import ( AudioTooLargeServiceError, NoAudioUploadedServiceError, diff --git a/api/tests/unit_tests/controllers/web/test_completion.py b/api/tests/unit_tests/controllers/web/test_completion.py index e88bcf2ae6..49039d03fe 100644 --- a/api/tests/unit_tests/controllers/web/test_completion.py +++ b/api/tests/unit_tests/controllers/web/test_completion.py @@ -7,6 +7,7 @@ from unittest.mock import MagicMock, patch import pytest from flask import Flask +from graphon.model_runtime.errors.invoke import InvokeError from controllers.web.completion import ChatApi, ChatStopApi, CompletionApi, CompletionStopApi from controllers.web.error import ( @@ -18,7 +19,6 @@ from controllers.web.error import ( ProviderQuotaExceededError, ) from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError -from dify_graph.model_runtime.errors.invoke import InvokeError def _completion_app() -> SimpleNamespace: diff --git a/api/tests/unit_tests/core/agent/test_base_agent_runner.py b/api/tests/unit_tests/core/agent/test_base_agent_runner.py index 683cc0e36f..db4b293b16 100644 --- a/api/tests/unit_tests/core/agent/test_base_agent_runner.py +++ 
b/api/tests/unit_tests/core/agent/test_base_agent_runner.py @@ -621,7 +621,7 @@ class TestConvertDatasetRetrieverTool: class TestBaseAgentRunnerInit: def test_init_sets_stream_tool_call_and_files(self, mocker): session = mocker.MagicMock() - session.query.return_value.where.return_value.count.return_value = 2 + session.scalar.return_value = 2 mocker.patch.object(module.db, "session", session) mocker.patch.object(BaseAgentRunner, "organize_agent_history", return_value=[]) diff --git a/api/tests/unit_tests/core/agent/test_cot_agent_runner.py b/api/tests/unit_tests/core/agent/test_cot_agent_runner.py index f6d1edbaf0..bc7aea0ef9 100644 --- a/api/tests/unit_tests/core/agent/test_cot_agent_runner.py +++ b/api/tests/unit_tests/core/agent/test_cot_agent_runner.py @@ -2,11 +2,11 @@ import json from unittest.mock import MagicMock import pytest +from graphon.model_runtime.entities.llm_entities import LLMUsage from core.agent.cot_agent_runner import CotAgentRunner from core.agent.entities import AgentScratchpadUnit from core.agent.errors import AgentMaxIterationError -from dify_graph.model_runtime.entities.llm_entities import LLMUsage class DummyRunner(CotAgentRunner): @@ -387,7 +387,7 @@ class TestRun: runner.update_prompt_message_tool.assert_called_once() def test_historic_with_assistant_and_tool_calls(self, runner): - from dify_graph.model_runtime.entities.message_entities import AssistantPromptMessage, ToolPromptMessage + from graphon.model_runtime.entities.message_entities import AssistantPromptMessage, ToolPromptMessage assistant = AssistantPromptMessage(content="thinking") assistant.tool_calls = [MagicMock(function=MagicMock(name="tool", arguments='{"a":1}'))] @@ -400,7 +400,7 @@ class TestRun: assert isinstance(result, list) def test_historic_final_flush_branch(self, runner): - from dify_graph.model_runtime.entities.message_entities import AssistantPromptMessage + from graphon.model_runtime.entities.message_entities import AssistantPromptMessage assistant = 
AssistantPromptMessage(content="final") runner.history_prompt_messages = [assistant] @@ -458,7 +458,7 @@ class TestFillInputsEdgeCases: class TestOrganizeHistoricPromptMessagesExtended: def test_user_message_flushes_scratchpad(self, runner, mocker): - from dify_graph.model_runtime.entities.message_entities import UserPromptMessage + from graphon.model_runtime.entities.message_entities import UserPromptMessage user_message = UserPromptMessage(content="Hi") @@ -473,7 +473,7 @@ class TestOrganizeHistoricPromptMessagesExtended: assert result == ["final"] def test_tool_message_without_scratchpad_raises(self, runner): - from dify_graph.model_runtime.entities.message_entities import ToolPromptMessage + from graphon.model_runtime.entities.message_entities import ToolPromptMessage runner.history_prompt_messages = [ToolPromptMessage(content="obs", tool_call_id="1")] diff --git a/api/tests/unit_tests/core/agent/test_cot_chat_agent_runner.py b/api/tests/unit_tests/core/agent/test_cot_chat_agent_runner.py index f9d69d1196..97206019b9 100644 --- a/api/tests/unit_tests/core/agent/test_cot_chat_agent_runner.py +++ b/api/tests/unit_tests/core/agent/test_cot_chat_agent_runner.py @@ -1,9 +1,9 @@ from unittest.mock import MagicMock, patch import pytest +from graphon.model_runtime.entities.message_entities import TextPromptMessageContent from core.agent.cot_chat_agent_runner import CotChatAgentRunner -from dify_graph.model_runtime.entities.message_entities import TextPromptMessageContent from tests.unit_tests.core.agent.conftest import ( DummyAgentConfig, DummyAppConfig, @@ -93,7 +93,7 @@ class TestOrganizeUserQuery: @patch("core.agent.cot_chat_agent_runner.UserPromptMessage") @patch("core.agent.cot_chat_agent_runner.file_manager.to_prompt_message_content") def test_organize_user_query_with_image_file_default_config(self, mock_to_prompt, mock_user_prompt, runner): - from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent + from 
graphon.model_runtime.entities.message_entities import ImagePromptMessageContent mock_content = ImagePromptMessageContent( url="http://test", @@ -118,7 +118,7 @@ class TestOrganizeUserQuery: @patch("core.agent.cot_chat_agent_runner.UserPromptMessage") @patch("core.agent.cot_chat_agent_runner.file_manager.to_prompt_message_content") def test_organize_user_query_with_image_file_high_detail(self, mock_to_prompt, mock_user_prompt, runner): - from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent + from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent mock_content = ImagePromptMessageContent( url="http://test", diff --git a/api/tests/unit_tests/core/agent/test_cot_completion_agent_runner.py b/api/tests/unit_tests/core/agent/test_cot_completion_agent_runner.py index ab822bb57d..defc8b4b64 100644 --- a/api/tests/unit_tests/core/agent/test_cot_completion_agent_runner.py +++ b/api/tests/unit_tests/core/agent/test_cot_completion_agent_runner.py @@ -1,15 +1,15 @@ import json import pytest - -from core.agent.cot_completion_agent_runner import CotCompletionAgentRunner -from dify_graph.model_runtime.entities.message_entities import ( +from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, ImagePromptMessageContent, TextPromptMessageContent, UserPromptMessage, ) +from core.agent.cot_completion_agent_runner import CotCompletionAgentRunner + # ----------------------------- # Fixtures # ----------------------------- diff --git a/api/tests/unit_tests/core/agent/test_fc_agent_runner.py b/api/tests/unit_tests/core/agent/test_fc_agent_runner.py index 299c9b31d2..a44a0650eb 100644 --- a/api/tests/unit_tests/core/agent/test_fc_agent_runner.py +++ b/api/tests/unit_tests/core/agent/test_fc_agent_runner.py @@ -3,19 +3,19 @@ from typing import Any from unittest.mock import MagicMock import pytest - -from core.agent.errors import AgentMaxIterationError -from core.agent.fc_agent_runner import 
FunctionCallAgentRunner -from core.app.apps.base_app_queue_manager import PublishFrom -from core.app.entities.queue_entities import QueueMessageFileEvent -from dify_graph.model_runtime.entities.llm_entities import LLMUsage -from dify_graph.model_runtime.entities.message_entities import ( +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.model_runtime.entities.message_entities import ( DocumentPromptMessageContent, ImagePromptMessageContent, TextPromptMessageContent, UserPromptMessage, ) +from core.agent.errors import AgentMaxIterationError +from core.agent.fc_agent_runner import FunctionCallAgentRunner +from core.app.apps.base_app_queue_manager import PublishFrom +from core.app.entities.queue_entities import QueueMessageFileEvent + # ============================== # Dummy Helper Classes # ============================== diff --git a/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_model_config_converter.py b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_model_config_converter.py index aed1651511..5ee66da94a 100644 --- a/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_model_config_converter.py +++ b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_model_config_converter.py @@ -2,6 +2,8 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest +from graphon.model_runtime.entities.llm_entities import LLMMode +from graphon.model_runtime.entities.model_entities import ModelPropertyKey from core.app.app_config.easy_ui_based_app.model_config.converter import ModelConfigConverter from core.entities.model_entities import ModelStatus @@ -10,8 +12,6 @@ from core.errors.error import ( ProviderTokenNotInitError, QuotaExceededError, ) -from dify_graph.model_runtime.entities.llm_entities import LLMMode -from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey class TestModelConfigConverter: @@ -73,7 +73,7 @@ class TestModelConfigConverter: 
mock_manager = MagicMock() mock_manager.get_provider_model_bundle.return_value = mock_provider_bundle mocker.patch( - "core.app.app_config.easy_ui_based_app.model_config.converter.ProviderManager", + "core.app.app_config.easy_ui_based_app.model_config.converter.create_plugin_provider_manager", return_value=mock_manager, ) return mock_manager @@ -109,7 +109,7 @@ class TestModelConfigConverter: mock_manager = MagicMock() mock_manager.get_provider_model_bundle.return_value = mock_provider_bundle mocker.patch( - "core.app.app_config.easy_ui_based_app.model_config.converter.ProviderManager", + "core.app.app_config.easy_ui_based_app.model_config.converter.create_plugin_provider_manager", return_value=mock_manager, ) @@ -124,7 +124,7 @@ class TestModelConfigConverter: mock_manager = MagicMock() mock_manager.get_provider_model_bundle.return_value = mock_provider_bundle mocker.patch( - "core.app.app_config.easy_ui_based_app.model_config.converter.ProviderManager", + "core.app.app_config.easy_ui_based_app.model_config.converter.create_plugin_provider_manager", return_value=mock_manager, ) @@ -141,7 +141,7 @@ class TestModelConfigConverter: mock_manager = MagicMock() mock_manager.get_provider_model_bundle.return_value = mock_provider_bundle mocker.patch( - "core.app.app_config.easy_ui_based_app.model_config.converter.ProviderManager", + "core.app.app_config.easy_ui_based_app.model_config.converter.create_plugin_provider_manager", return_value=mock_manager, ) @@ -158,7 +158,7 @@ class TestModelConfigConverter: mock_manager = MagicMock() mock_manager.get_provider_model_bundle.return_value = mock_provider_bundle mocker.patch( - "core.app.app_config.easy_ui_based_app.model_config.converter.ProviderManager", + "core.app.app_config.easy_ui_based_app.model_config.converter.create_plugin_provider_manager", return_value=mock_manager, ) @@ -183,7 +183,7 @@ class TestModelConfigConverter: mock_manager = MagicMock() mock_manager.get_provider_model_bundle.return_value = 
mock_provider_bundle mocker.patch( - "core.app.app_config.easy_ui_based_app.model_config.converter.ProviderManager", + "core.app.app_config.easy_ui_based_app.model_config.converter.create_plugin_provider_manager", return_value=mock_manager, ) @@ -200,7 +200,7 @@ class TestModelConfigConverter: mock_manager = MagicMock() mock_manager.get_provider_model_bundle.return_value = mock_provider_bundle mocker.patch( - "core.app.app_config.easy_ui_based_app.model_config.converter.ProviderManager", + "core.app.app_config.easy_ui_based_app.model_config.converter.create_plugin_provider_manager", return_value=mock_manager, ) diff --git a/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_model_config_manager.py b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_model_config_manager.py index e2ba276d8e..68bca485bb 100644 --- a/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_model_config_manager.py +++ b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_model_config_manager.py @@ -43,6 +43,17 @@ def valid_config(): class TestModelConfigManager: + @staticmethod + def _patch_model_assembly(mocker, *, provider_entities, model_list): + assembly = MagicMock() + assembly.model_provider_factory.get_providers.return_value = provider_entities + assembly.provider_manager.get_configurations.return_value.get_models.return_value = model_list + mocker.patch( + "core.app.app_config.easy_ui_based_app.model_config.manager.create_plugin_model_assembly", + return_value=assembly, + ) + return assembly + # ========================================================== # convert # ========================================================== @@ -97,11 +108,11 @@ class TestModelConfigManager: # ========================================================== def test_validate_and_set_defaults_success(self, mocker, valid_config, provider_entities, valid_model_list): - mock_factory = 
mocker.patch("core.app.app_config.easy_ui_based_app.model_config.manager.ModelProviderFactory") - mock_factory.return_value.get_providers.return_value = provider_entities - - mock_pm = mocker.patch("core.app.app_config.easy_ui_based_app.model_config.manager.ProviderManager") - mock_pm.return_value.get_configurations.return_value.get_models.return_value = valid_model_list + self._patch_model_assembly( + mocker, + provider_entities=provider_entities, + model_list=valid_model_list, + ) updated_config, keys = ModelConfigManager.validate_and_set_defaults("tenant1", valid_config) @@ -118,51 +129,39 @@ class TestModelConfigManager: def test_validate_and_set_defaults_missing_provider(self, mocker, provider_entities): config = {"model": {"name": "gpt-4", "completion_params": {}}} - - mock_factory = mocker.patch("core.app.app_config.easy_ui_based_app.model_config.manager.ModelProviderFactory") - mock_factory.return_value.get_providers.return_value = provider_entities + self._patch_model_assembly(mocker, provider_entities=provider_entities, model_list=[]) with pytest.raises(ValueError, match="model.provider is required"): ModelConfigManager.validate_and_set_defaults("tenant1", config) def test_validate_and_set_defaults_invalid_provider(self, mocker, provider_entities): config = {"model": {"provider": "invalid/provider", "name": "gpt-4", "completion_params": {}}} - - mock_factory = mocker.patch("core.app.app_config.easy_ui_based_app.model_config.manager.ModelProviderFactory") - mock_factory.return_value.get_providers.return_value = provider_entities + self._patch_model_assembly(mocker, provider_entities=provider_entities, model_list=[]) with pytest.raises(ValueError, match="model.provider is required"): ModelConfigManager.validate_and_set_defaults("tenant1", config) def test_validate_and_set_defaults_missing_name(self, mocker, provider_entities): config = {"model": {"provider": "openai/gpt", "completion_params": {}}} - - mock_factory = 
mocker.patch("core.app.app_config.easy_ui_based_app.model_config.manager.ModelProviderFactory") - mock_factory.return_value.get_providers.return_value = provider_entities + self._patch_model_assembly(mocker, provider_entities=provider_entities, model_list=[]) with pytest.raises(ValueError, match="model.name is required"): ModelConfigManager.validate_and_set_defaults("tenant1", config) def test_validate_and_set_defaults_empty_models(self, mocker, provider_entities): config = {"model": {"provider": "openai/gpt", "name": "gpt-4", "completion_params": {}}} - - mock_factory = mocker.patch("core.app.app_config.easy_ui_based_app.model_config.manager.ModelProviderFactory") - mock_factory.return_value.get_providers.return_value = provider_entities - - mock_pm = mocker.patch("core.app.app_config.easy_ui_based_app.model_config.manager.ProviderManager") - mock_pm.return_value.get_configurations.return_value.get_models.return_value = [] + self._patch_model_assembly(mocker, provider_entities=provider_entities, model_list=[]) with pytest.raises(ValueError, match="must be in the specified model list"): ModelConfigManager.validate_and_set_defaults("tenant1", config) def test_validate_and_set_defaults_invalid_model_name(self, mocker, provider_entities, valid_model_list): config = {"model": {"provider": "openai/gpt", "name": "invalid", "completion_params": {}}} - - mock_factory = mocker.patch("core.app.app_config.easy_ui_based_app.model_config.manager.ModelProviderFactory") - mock_factory.return_value.get_providers.return_value = provider_entities - - mock_pm = mocker.patch("core.app.app_config.easy_ui_based_app.model_config.manager.ProviderManager") - mock_pm.return_value.get_configurations.return_value.get_models.return_value = valid_model_list + self._patch_model_assembly( + mocker, + provider_entities=provider_entities, + model_list=valid_model_list, + ) with pytest.raises(ValueError, match="must be in the specified model list"): 
ModelConfigManager.validate_and_set_defaults("tenant1", config) @@ -173,12 +172,7 @@ class TestModelConfigManager: model.model_properties = {} config = {"model": {"provider": "openai/gpt", "name": "gpt-4", "completion_params": {}}} - - mock_factory = mocker.patch("core.app.app_config.easy_ui_based_app.model_config.manager.ModelProviderFactory") - mock_factory.return_value.get_providers.return_value = provider_entities - - mock_pm = mocker.patch("core.app.app_config.easy_ui_based_app.model_config.manager.ProviderManager") - mock_pm.return_value.get_configurations.return_value.get_models.return_value = [model] + self._patch_model_assembly(mocker, provider_entities=provider_entities, model_list=[model]) updated_config, _ = ModelConfigManager.validate_and_set_defaults("tenant1", config) @@ -186,12 +180,11 @@ class TestModelConfigManager: def test_validate_and_set_defaults_missing_completion_params(self, mocker, provider_entities, valid_model_list): config = {"model": {"provider": "openai/gpt", "name": "gpt-4"}} - - mock_factory = mocker.patch("core.app.app_config.easy_ui_based_app.model_config.manager.ModelProviderFactory") - mock_factory.return_value.get_providers.return_value = provider_entities - - mock_pm = mocker.patch("core.app.app_config.easy_ui_based_app.model_config.manager.ProviderManager") - mock_pm.return_value.get_configurations.return_value.get_models.return_value = valid_model_list + self._patch_model_assembly( + mocker, + provider_entities=provider_entities, + model_list=valid_model_list, + ) with pytest.raises(ValueError, match="completion_params is required"): ModelConfigManager.validate_and_set_defaults("tenant1", config) @@ -212,16 +205,9 @@ class TestModelConfigManager: # Mock ModelProviderID to return formatted provider mock_provider_id = mocker.patch("core.app.app_config.easy_ui_based_app.model_config.manager.ModelProviderID") mock_provider_id.return_value = "openai/gpt" - - # Mock provider factory - mock_factory = 
mocker.patch("core.app.app_config.easy_ui_based_app.model_config.manager.ModelProviderFactory") provider_entity = MagicMock() provider_entity.provider = "openai/gpt" - mock_factory.return_value.get_providers.return_value = [provider_entity] - - # Mock provider manager - mock_pm = mocker.patch("core.app.app_config.easy_ui_based_app.model_config.manager.ProviderManager") - mock_pm.return_value.get_configurations.return_value.get_models.return_value = valid_model_list + self._patch_model_assembly(mocker, provider_entities=[provider_entity], model_list=valid_model_list) updated_config, _ = ModelConfigManager.validate_and_set_defaults("tenant1", config) diff --git a/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_variables_manager.py b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_variables_manager.py index 5def29b741..e2f3c16335 100644 --- a/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_variables_manager.py +++ b/api/tests/unit_tests/core/app/app_config/easy_ui_based_app/test_variables_manager.py @@ -1,9 +1,9 @@ import pytest +from graphon.variables.input_entities import VariableEntityType from core.app.app_config.easy_ui_based_app.variables.manager import ( BasicVariablesConfigManager, ) -from dify_graph.variables.input_entities import VariableEntityType class TestBasicVariablesConfigManagerConvert: diff --git a/api/tests/unit_tests/core/app/app_config/features/file_upload/test_manager.py b/api/tests/unit_tests/core/app/app_config/features/file_upload/test_manager.py index de99833aac..8bde9c1f97 100644 --- a/api/tests/unit_tests/core/app/app_config/features/file_upload/test_manager.py +++ b/api/tests/unit_tests/core/app/app_config/features/file_upload/test_manager.py @@ -1,6 +1,7 @@ +from graphon.file import FileTransferMethod, FileUploadConfig, ImageConfig +from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent + from core.app.app_config.features.file_upload.manager import 
FileUploadConfigManager -from dify_graph.file.models import FileTransferMethod, FileUploadConfig, ImageConfig -from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent def test_convert_with_vision(): diff --git a/api/tests/unit_tests/core/app/app_config/test_entities.py b/api/tests/unit_tests/core/app/app_config/test_entities.py index eafdf99c16..000f83cd5a 100644 --- a/api/tests/unit_tests/core/app/app_config/test_entities.py +++ b/api/tests/unit_tests/core/app/app_config/test_entities.py @@ -1,10 +1,10 @@ import pytest +from graphon.variables.input_entities import VariableEntity, VariableEntityType from core.app.app_config.entities import ( DatasetRetrieveConfigEntity, PromptTemplateEntity, ) -from dify_graph.variables.input_entities import VariableEntity, VariableEntityType class TestAppConfigEntities: diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_generator.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_generator.py index 441d2fcd17..af5d203f12 100644 --- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_generator.py +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_generator.py @@ -9,10 +9,17 @@ from pydantic import BaseModel, ValidationError from constants import UUID_NIL from core.app.app_config.entities import AppAdditionalFeatures, WorkflowUIBasedAppConfig -from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator, _refresh_model +from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator +from core.app.apps.advanced_chat.generate_task_pipeline import ( + ConversationSnapshot, + MessageSnapshot, + WorkflowSnapshot, +) from core.app.apps.exc import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom from core.ops.ops_trace_manager import TraceQueueManager +from libs.datetime_utils import naive_utc_now +from models.enums import MessageStatus from models.model 
import AppMode @@ -363,8 +370,15 @@ class TestAdvancedChatAppGeneratorInternals: workflow_run_id="run-id", ) + workflow = SimpleNamespace(id="wf-1", tenant_id="tenant", features={"feature": True}, features_dict={}) conversation = SimpleNamespace(id="conv-1", mode=AppMode.ADVANCED_CHAT, override_model_configs=None) - message = SimpleNamespace(id="msg-1") + message = SimpleNamespace( + id="msg-1", + query="hello", + created_at=naive_utc_now(), + status=MessageStatus.NORMAL, + answer="", + ) db_session = SimpleNamespace(commit=MagicMock(), refresh=MagicMock(), close=MagicMock()) captured: dict[str, object] = {} thread_data: dict[str, object] = {} @@ -394,19 +408,6 @@ class TestAdvancedChatAppGeneratorInternals: thread_data["started"] = True monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.threading.Thread", _Thread) - monkeypatch.setattr("core.app.apps.advanced_chat.app_generator._refresh_model", lambda session, model: model) - - class _Session: - def __init__(self, *args, **kwargs): - _ = args, kwargs - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc, tb): - return False - - monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.Session", _Session) monkeypatch.setattr( "core.app.apps.advanced_chat.app_generator.db", SimpleNamespace(engine=object(), session=db_session) ) @@ -424,7 +425,7 @@ class TestAdvancedChatAppGeneratorInternals: pause_state_config = SimpleNamespace(session_factory="session-factory", state_owner_user_id="owner") response = generator._generate( - workflow=SimpleNamespace(features={"feature": True}), + workflow=workflow, user=SimpleNamespace(id="user"), invoke_from=InvokeFrom.WEB_APP, application_generate_entity=application_generate_entity, @@ -444,6 +445,9 @@ class TestAdvancedChatAppGeneratorInternals: db_session.refresh.assert_called_once_with(conversation) db_session.close.assert_called_once() assert captured["draft_var_saver_factory"] == "draft-factory" + assert isinstance(captured["workflow"], 
WorkflowSnapshot) + assert isinstance(captured["conversation"], ConversationSnapshot) + assert isinstance(captured["message"], MessageSnapshot) def test_generate_internal_flow_with_existing_records_skips_init(self, monkeypatch): generator = AdvancedChatAppGenerator() @@ -464,8 +468,15 @@ class TestAdvancedChatAppGeneratorInternals: workflow_run_id="run-id", ) + workflow = SimpleNamespace(id="wf-2", tenant_id="tenant", features={}, features_dict={}) conversation = SimpleNamespace(id="conv-2", mode=AppMode.ADVANCED_CHAT, override_model_configs=None) - message = SimpleNamespace(id="msg-2") + message = SimpleNamespace( + id="msg-2", + query="hello", + created_at=naive_utc_now(), + status=MessageStatus.NORMAL, + answer="", + ) db_session = SimpleNamespace(close=MagicMock(), commit=MagicMock(), refresh=MagicMock()) init_records = MagicMock() thread_data: dict[str, object] = {} @@ -491,19 +502,6 @@ class TestAdvancedChatAppGeneratorInternals: thread_data["started"] = True monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.threading.Thread", _Thread) - monkeypatch.setattr("core.app.apps.advanced_chat.app_generator._refresh_model", lambda session, model: model) - - class _Session: - def __init__(self, *args, **kwargs): - _ = args, kwargs - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc, tb): - return False - - monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.Session", _Session) monkeypatch.setattr( "core.app.apps.advanced_chat.app_generator.db", SimpleNamespace(engine=object(), session=db_session) ) @@ -519,7 +517,7 @@ class TestAdvancedChatAppGeneratorInternals: ) response = generator._generate( - workflow=SimpleNamespace(features={}), + workflow=workflow, user=SimpleNamespace(id="user"), invoke_from=InvokeFrom.WEB_APP, application_generate_entity=application_generate_entity, @@ -940,10 +938,16 @@ class TestAdvancedChatAppGeneratorInternals: with pytest.raises(GenerateTaskStoppedError): 
generator._handle_advanced_chat_response( application_generate_entity=application_generate_entity, - workflow=SimpleNamespace(), + workflow=WorkflowSnapshot(id="wf", tenant_id="tenant", features_dict={}), queue_manager=SimpleNamespace(), - conversation=SimpleNamespace(id="conv", mode=AppMode.ADVANCED_CHAT), - message=SimpleNamespace(id="msg"), + conversation=ConversationSnapshot(id="conv", mode=AppMode.ADVANCED_CHAT), + message=MessageSnapshot( + id="msg", + query="hello", + created_at=naive_utc_now(), + status=MessageStatus.NORMAL, + answer="", + ), user=SimpleNamespace(), draft_var_saver_factory=lambda **kwargs: None, stream=False, @@ -981,10 +985,16 @@ class TestAdvancedChatAppGeneratorInternals: with pytest.raises(ValueError, match="other error"): generator._handle_advanced_chat_response( application_generate_entity=application_generate_entity, - workflow=SimpleNamespace(), + workflow=WorkflowSnapshot(id="wf", tenant_id="tenant", features_dict={}), queue_manager=SimpleNamespace(), - conversation=SimpleNamespace(id="conv", mode=AppMode.ADVANCED_CHAT), - message=SimpleNamespace(id="msg"), + conversation=ConversationSnapshot(id="conv", mode=AppMode.ADVANCED_CHAT), + message=MessageSnapshot( + id="msg", + query="hello", + created_at=naive_utc_now(), + status=MessageStatus.NORMAL, + answer="", + ), user=SimpleNamespace(), draft_var_saver_factory=lambda **kwargs: None, stream=False, @@ -992,31 +1002,6 @@ class TestAdvancedChatAppGeneratorInternals: logger_exception.assert_called_once() - def test_refresh_model_returns_detached_model(self, monkeypatch): - source_model = SimpleNamespace(id="source-id") - detached_model = SimpleNamespace(id="source-id", detached=True) - - class _Session: - def __init__(self, *args, **kwargs): - _ = args, kwargs - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc, tb): - return False - - def get(self, model_type, model_id): - _ = model_type - return detached_model if model_id == "source-id" else None - - 
monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.Session", _Session) - monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.db", SimpleNamespace(engine=object())) - - refreshed = _refresh_model(session=SimpleNamespace(), model=source_model) - - assert refreshed is detached_model - def test_generate_worker_handles_invoke_auth_error(self, monkeypatch): generator = AdvancedChatAppGenerator() generator._dialogue_count = 1 @@ -1053,7 +1038,7 @@ class TestAdvancedChatAppGeneratorInternals: _ = kwargs def run(self): - from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError + from graphon.model_runtime.errors.invoke import InvokeAuthorizationError raise InvokeAuthorizationError("bad key") diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py index 15aceef2c7..061719d15a 100644 --- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_conversation_variables.py @@ -3,14 +3,27 @@ from unittest.mock import MagicMock, patch from uuid import uuid4 +from graphon.variables import SegmentType from sqlalchemy.orm import Session from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom -from dify_graph.variables import SegmentType from factories import variable_factory from models import ConversationVariable, Workflow +MINIMAL_GRAPH = { + "nodes": [ + { + "id": "start", + "data": { + "type": "start", + "title": "Start", + }, + } + ], + "edges": [], +} + class TestAdvancedChatAppRunnerConversationVariables: """Test that AdvancedChatAppRunner correctly handles conversation variables.""" @@ -49,7 +62,7 @@ class TestAdvancedChatAppRunnerConversationVariables: mock_workflow.app_id = app_id 
mock_workflow.id = workflow_id mock_workflow.type = "chat" - mock_workflow.graph_dict = {} + mock_workflow.graph_dict = MINIMAL_GRAPH mock_workflow.environment_variables = [] # Create existing conversation variable (only var1 exists in DB) @@ -200,7 +213,7 @@ class TestAdvancedChatAppRunnerConversationVariables: mock_workflow.app_id = app_id mock_workflow.id = workflow_id mock_workflow.type = "chat" - mock_workflow.graph_dict = {} + mock_workflow.graph_dict = MINIMAL_GRAPH mock_workflow.environment_variables = [] # Mock conversation and message @@ -349,7 +362,7 @@ class TestAdvancedChatAppRunnerConversationVariables: mock_workflow.app_id = app_id mock_workflow.id = workflow_id mock_workflow.type = "chat" - mock_workflow.graph_dict = {} + mock_workflow.graph_dict = MINIMAL_GRAPH mock_workflow.environment_variables = [] # Create existing conversation variables (both exist in DB) diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_input_moderation.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_input_moderation.py index 5792a2f1e2..079df0b4e6 100644 --- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_input_moderation.py +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_app_runner_input_moderation.py @@ -8,6 +8,19 @@ from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, from core.app.entities.queue_entities import QueueStopEvent from core.moderation.base import ModerationError +MINIMAL_GRAPH = { + "nodes": [ + { + "id": "start", + "data": { + "type": "start", + "title": "Start", + }, + } + ], + "edges": [], +} + @pytest.fixture def build_runner(): @@ -30,7 +43,7 @@ def build_runner(): mock_workflow.tenant_id = str(uuid4()) mock_workflow.app_id = app_id mock_workflow.type = "chat" - mock_workflow.graph_dict = {} + mock_workflow.graph_dict = MINIMAL_GRAPH mock_workflow.environment_variables = [] mock_app_config = MagicMock() diff --git 
a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_response_converter.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_response_converter.py index 5b199e0c52..e9fdeefee4 100644 --- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_response_converter.py +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_response_converter.py @@ -1,5 +1,7 @@ from collections.abc import Generator +from graphon.enums import WorkflowNodeExecutionStatus + from core.app.apps.advanced_chat.generate_response_converter import AdvancedChatAppGenerateResponseConverter from core.app.entities.task_entities import ( ChatbotAppBlockingResponse, @@ -10,7 +12,6 @@ from core.app.entities.task_entities import ( NodeStartStreamResponse, PingStreamResponse, ) -from dify_graph.enums import WorkflowNodeExecutionStatus class TestAdvancedChatGenerateResponseConverter: diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline.py index 83a6e0f231..a6d8598955 100644 --- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline.py +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline.py @@ -6,6 +6,8 @@ from types import SimpleNamespace from unittest import mock import pytest +from graphon.entities.pause_reason import HumanInputRequired +from graphon.enums import WorkflowExecutionStatus from core.app.apps.advanced_chat import generate_task_pipeline as pipeline_module from core.app.entities.app_invoke_entities import InvokeFrom @@ -17,11 +19,9 @@ from core.app.entities.queue_entities import ( QueueWorkflowSucceededEvent, ) from core.app.entities.task_entities import StreamEvent -from dify_graph.entities.pause_reason import HumanInputRequired -from dify_graph.enums import WorkflowExecutionStatus from models.enums import MessageStatus from models.execution_extra_content import 
HumanInputContent -from models.model import EndUser +from models.model import AppMode, EndUser def _build_pipeline() -> pipeline_module.AdvancedChatAppGenerateTaskPipeline: @@ -137,7 +137,6 @@ def test_handle_workflow_paused_event_persists_human_input_extra_content() -> No actions=[], node_id="node-1", node_title="Approval", - form_token="token-1", resolved_default_values={}, ) event = QueueWorkflowPausedEvent(reasons=[reason], outputs={}, paused_nodes=["node-1"]) @@ -160,8 +159,8 @@ def test_resume_appends_chunks_to_paused_answer() -> None: task_id="task-1", ) queue_manager = SimpleNamespace(graph_runtime_state=None) - conversation = SimpleNamespace(id="conversation-1", mode="advanced-chat") - message = SimpleNamespace( + conversation = pipeline_module.ConversationSnapshot(id="conversation-1", mode=AppMode.ADVANCED_CHAT) + message = pipeline_module.MessageSnapshot( id="message-1", created_at=datetime(2024, 1, 1), query="hello", @@ -171,7 +170,7 @@ def test_resume_appends_chunks_to_paused_answer() -> None: user = EndUser() user.id = "user-1" user.session_id = "session-1" - workflow = SimpleNamespace(id="workflow-1", tenant_id="tenant-1", features_dict={}) + workflow = pipeline_module.WorkflowSnapshot(id="workflow-1", tenant_id="tenant-1", features_dict={}) pipeline = pipeline_module.AdvancedChatAppGenerateTaskPipeline( application_generate_entity=application_generate_entity, @@ -185,14 +184,33 @@ def test_resume_appends_chunks_to_paused_answer() -> None: draft_var_saver_factory=SimpleNamespace(), ) - pipeline._get_message = mock.Mock(return_value=message) + stored_message = SimpleNamespace( + id="message-1", + answer="before", + status=MessageStatus.PAUSED, + updated_at=None, + provider_response_latency=0, + message_tokens=0, + message_unit_price=0, + message_price_unit=0, + answer_tokens=0, + answer_unit_price=0, + answer_price_unit=0, + total_price=0, + currency="USD", + message_metadata=None, + invoke_from=InvokeFrom.WEB_APP, + from_account_id=None, + 
from_end_user_id="user-1", + ) + pipeline._get_message = mock.Mock(return_value=stored_message) pipeline._recorded_files = [] list(pipeline._handle_text_chunk_event(QueueTextChunkEvent(text="after"))) pipeline._save_message(session=mock.Mock()) - assert message.answer == "beforeafter" - assert message.status == MessageStatus.NORMAL + assert stored_message.answer == "beforeafter" + assert stored_message.status == MessageStatus.NORMAL def test_workflow_succeeded_emits_message_end_before_workflow_finished() -> None: diff --git a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py index 0a244b3fea..82b2e51019 100644 --- a/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py +++ b/api/tests/unit_tests/core/app/apps/advanced_chat/test_generate_task_pipeline_core.py @@ -1,13 +1,19 @@ from __future__ import annotations from contextlib import contextmanager -from datetime import datetime from types import SimpleNamespace import pytest +from graphon.enums import BuiltinNodeTypes +from graphon.runtime import GraphRuntimeState, VariablePool from core.app.app_config.entities import AppAdditionalFeatures, WorkflowUIBasedAppConfig -from core.app.apps.advanced_chat.generate_task_pipeline import AdvancedChatAppGenerateTaskPipeline +from core.app.apps.advanced_chat.generate_task_pipeline import ( + AdvancedChatAppGenerateTaskPipeline, + ConversationSnapshot, + MessageSnapshot, + WorkflowSnapshot, +) from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom from core.app.entities.queue_entities import ( QueueAdvancedChatMessageEndEvent, @@ -42,11 +48,11 @@ from core.app.entities.task_entities import ( PingStreamResponse, ) from core.base.tts.app_generator_tts_publisher import AudioTrunk -from dify_graph.enums import BuiltinNodeTypes -from dify_graph.runtime import GraphRuntimeState, VariablePool -from 
dify_graph.system_variable import SystemVariable +from core.workflow.system_variables import build_system_variables +from libs.datetime_utils import naive_utc_now from models.enums import MessageStatus from models.model import AppMode, EndUser +from tests.workflow_test_utils import build_test_variable_pool def _make_pipeline(): @@ -72,15 +78,15 @@ def _make_pipeline(): workflow_run_id="run-id", ) - message = SimpleNamespace( + message = MessageSnapshot( id="message-id", query="hello", - created_at=datetime.utcnow(), + created_at=naive_utc_now(), status=MessageStatus.NORMAL, answer="", ) - conversation = SimpleNamespace(id="conv-id", mode=AppMode.ADVANCED_CHAT) - workflow = SimpleNamespace(id="workflow-id", tenant_id="tenant", features_dict={}) + conversation = ConversationSnapshot(id="conv-id", mode=AppMode.ADVANCED_CHAT) + workflow = WorkflowSnapshot(id="workflow-id", tenant_id="tenant", features_dict={}) user = EndUser(tenant_id="tenant", type="session", name="tester", session_id="session") pipeline = AdvancedChatAppGenerateTaskPipeline( @@ -166,7 +172,7 @@ class TestAdvancedChatGenerateTaskPipeline: def test_handle_workflow_started_event_sets_run_id(self, monkeypatch): pipeline = _make_pipeline() pipeline._graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")), + variable_pool=build_test_variable_pool(variables=build_system_variables(workflow_execution_id="run-id")), start_at=0.0, ) pipeline._workflow_response_converter.workflow_start_to_stream_response = lambda **kwargs: "started" @@ -256,7 +262,7 @@ class TestAdvancedChatGenerateTaskPipeline: node_id="node", node_type=BuiltinNodeTypes.LLM, node_title="LLM", - start_at=datetime.utcnow(), + start_at=naive_utc_now(), node_run_index=1, ) iter_next = QueueIterationNextEvent( @@ -272,7 +278,7 @@ class TestAdvancedChatGenerateTaskPipeline: node_id="node", node_type=BuiltinNodeTypes.LLM, node_title="LLM", - start_at=datetime.utcnow(), + 
start_at=naive_utc_now(), node_run_index=1, ) loop_start = QueueLoopStartEvent( @@ -280,7 +286,7 @@ class TestAdvancedChatGenerateTaskPipeline: node_id="node", node_type=BuiltinNodeTypes.LLM, node_title="LLM", - start_at=datetime.utcnow(), + start_at=naive_utc_now(), node_run_index=1, ) loop_next = QueueLoopNextEvent( @@ -296,7 +302,7 @@ class TestAdvancedChatGenerateTaskPipeline: node_id="node", node_type=BuiltinNodeTypes.LLM, node_title="LLM", - start_at=datetime.utcnow(), + start_at=naive_utc_now(), node_run_index=1, ) @@ -311,7 +317,7 @@ class TestAdvancedChatGenerateTaskPipeline: pipeline = _make_pipeline() pipeline._workflow_run_id = "run-id" pipeline._graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")), + variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")), start_at=0.0, ) pipeline._workflow_response_converter.workflow_finish_to_stream_response = lambda **kwargs: "finish" @@ -359,7 +365,7 @@ class TestAdvancedChatGenerateTaskPipeline: node_execution_id="exec", node_id="node", node_type=BuiltinNodeTypes.LLM, - start_at=datetime.utcnow(), + start_at=naive_utc_now(), inputs={}, outputs={}, process_data={}, @@ -369,7 +375,7 @@ class TestAdvancedChatGenerateTaskPipeline: node_execution_id="exec", node_id="node", node_type=BuiltinNodeTypes.LLM, - start_at=datetime.utcnow(), + start_at=naive_utc_now(), inputs={}, outputs={}, process_data={}, @@ -472,7 +478,7 @@ class TestAdvancedChatGenerateTaskPipeline: node_id="node", node_type=BuiltinNodeTypes.LLM, node_title="title", - expiration_time=datetime.utcnow(), + expiration_time=naive_utc_now(), ) assert list(pipeline._handle_human_input_form_filled_event(filled_event)) == ["filled"] @@ -522,7 +528,7 @@ class TestAdvancedChatGenerateTaskPipeline: self.items = items graph_runtime_state = GraphRuntimeState( - 
variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")), + variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")), start_at=0.0, ) @@ -556,7 +562,7 @@ class TestAdvancedChatGenerateTaskPipeline: def test_handle_message_end_event_applies_output_moderation(self, monkeypatch): pipeline = _make_pipeline() pipeline._graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")), + variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")), start_at=0.0, ) pipeline._base_task_pipeline.handle_output_moderation_when_task_finished = lambda answer: "safe" @@ -590,7 +596,7 @@ class TestAdvancedChatGenerateTaskPipeline: node_execution_id="exec", node_id="node", node_type=BuiltinNodeTypes.LLM, - start_at=datetime.utcnow(), + start_at=naive_utc_now(), inputs={}, outputs={}, process_data={}, diff --git a/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_generator.py b/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_generator.py index 53f26d1592..7dc4358150 100644 --- a/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_generator.py +++ b/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_generator.py @@ -1,12 +1,12 @@ import contextlib import pytest +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from pydantic import ValidationError from core.app.apps.agent_chat.app_generator import AgentChatAppGenerator from core.app.apps.exc import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import InvokeFrom -from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError class DummyAccount: diff --git a/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_runner.py b/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_runner.py index 
5603115b30..08250bc3b6 100644 --- a/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_runner.py +++ b/api/tests/unit_tests/core/app/apps/agent_chat/test_agent_chat_app_runner.py @@ -1,10 +1,10 @@ import pytest +from graphon.model_runtime.entities.llm_entities import LLMMode +from graphon.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey from core.agent.entities import AgentEntity from core.app.apps.agent_chat.app_runner import AgentChatAppRunner from core.moderation.base import ModerationError -from dify_graph.model_runtime.entities.llm_entities import LLMMode -from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey @pytest.fixture diff --git a/api/tests/unit_tests/core/app/apps/chat/test_app_generator_and_runner.py b/api/tests/unit_tests/core/app/apps/chat/test_app_generator_and_runner.py index 3cdffbb4cd..68bcffb0e8 100644 --- a/api/tests/unit_tests/core/app/apps/chat/test_app_generator_and_runner.py +++ b/api/tests/unit_tests/core/app/apps/chat/test_app_generator_and_runner.py @@ -2,6 +2,7 @@ from types import SimpleNamespace from unittest.mock import Mock, patch import pytest +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from core.app.apps.chat.app_generator import ChatAppGenerator from core.app.apps.chat.app_runner import ChatAppRunner @@ -9,7 +10,6 @@ from core.app.apps.exc import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import QueueAnnotationReplyEvent from core.moderation.base import ModerationError -from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError from models.model import AppMode diff --git a/api/tests/unit_tests/core/app/apps/chat/test_base_app_runner_multimodal.py b/api/tests/unit_tests/core/app/apps/chat/test_base_app_runner_multimodal.py index 67b3777c40..f255d2c7df 100644 --- 
a/api/tests/unit_tests/core/app/apps/chat/test_base_app_runner_multimodal.py +++ b/api/tests/unit_tests/core/app/apps/chat/test_base_app_runner_multimodal.py @@ -4,13 +4,13 @@ from unittest.mock import MagicMock, patch from uuid import uuid4 import pytest +from graphon.file import FileTransferMethod, FileType +from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent from core.app.apps.base_app_queue_manager import PublishFrom from core.app.apps.base_app_runner import AppRunner from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import QueueMessageFileEvent -from dify_graph.file.enums import FileTransferMethod, FileType -from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent from models.enums import CreatorUserRole diff --git a/api/tests/unit_tests/core/app/apps/common/test_graph_runtime_state_support.py b/api/tests/unit_tests/core/app/apps/common/test_graph_runtime_state_support.py index b0789bbc1e..4a94a2b4f1 100644 --- a/api/tests/unit_tests/core/app/apps/common/test_graph_runtime_state_support.py +++ b/api/tests/unit_tests/core/app/apps/common/test_graph_runtime_state_support.py @@ -1,15 +1,16 @@ from types import SimpleNamespace import pytest +from graphon.runtime import GraphRuntimeState, VariablePool from core.app.apps.common.graph_runtime_state_support import GraphRuntimeStateSupport -from dify_graph.runtime import GraphRuntimeState -from dify_graph.runtime.variable_pool import VariablePool -from dify_graph.system_variable import SystemVariable +from core.workflow.system_variables import build_system_variables +from core.workflow.variable_pool_initializer import add_variables_to_pool def _make_state(workflow_run_id: str | None) -> GraphRuntimeState: - variable_pool = VariablePool(system_variables=SystemVariable(workflow_execution_id=workflow_run_id)) + variable_pool = VariablePool() + add_variables_to_pool(variable_pool, 
build_system_variables(workflow_execution_id=workflow_run_id)) return GraphRuntimeState(variable_pool=variable_pool, start_at=0.0) diff --git a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py index 72430a3347..328cd12f12 100644 --- a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py +++ b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter.py @@ -1,8 +1,9 @@ from collections.abc import Mapping, Sequence +from graphon.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType +from graphon.variables.segments import ArrayFileSegment, FileSegment + from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter -from dify_graph.file import FILE_MODEL_IDENTITY, File, FileTransferMethod, FileType -from dify_graph.variables.segments import ArrayFileSegment, FileSegment class TestWorkflowResponseConverterFetchFilesFromVariableValue: @@ -12,7 +13,6 @@ class TestWorkflowResponseConverterFetchFilesFromVariableValue: """Create a test File object""" return File( id=file_id, - tenant_id="test_tenant", type=FileType.DOCUMENT, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="related_123", @@ -223,7 +223,7 @@ class TestWorkflowResponseConverterFetchFilesFromVariableValue: assert len(result) == 1 file_dict = result[0] assert file_dict["id"] == "property_test" - assert file_dict["tenant_id"] == "test_tenant" + assert "tenant_id" not in file_dict assert file_dict["type"] == "document" assert file_dict["transfer_method"] == "local_file" assert file_dict["filename"] == "property_test.txt" diff --git a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_human_input.py b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_human_input.py index 4ed7d73cd0..bc11bf4174 100644 --- 
a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_human_input.py +++ b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_human_input.py @@ -1,16 +1,17 @@ from datetime import UTC, datetime from types import SimpleNamespace +from graphon.entities import WorkflowStartReason +from graphon.runtime import GraphRuntimeState, VariablePool + from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import QueueHumanInputFormFilledEvent, QueueHumanInputFormTimeoutEvent -from dify_graph.entities.workflow_start_reason import WorkflowStartReason -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from core.workflow.system_variables import build_system_variables def _build_converter(): - system_variables = SystemVariable( + system_variables = build_system_variables( files=[], user_id="user-1", app_id="app-1", diff --git a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_resumption.py b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_resumption.py index 5879e8fb9b..c9e146ff12 100644 --- a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_resumption.py +++ b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_resumption.py @@ -1,15 +1,16 @@ from types import SimpleNamespace +from graphon.entities import WorkflowStartReason +from graphon.runtime import GraphRuntimeState, VariablePool + from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter from core.app.entities.app_invoke_entities import InvokeFrom -from dify_graph.entities.workflow_start_reason import WorkflowStartReason -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from 
core.workflow.system_variables import build_system_variables def _build_converter() -> WorkflowResponseConverter: """Construct a minimal WorkflowResponseConverter for testing.""" - system_variables = SystemVariable( + system_variables = build_system_variables( files=[], user_id="user-1", app_id="app-1", diff --git a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py index 374af5ddc4..0fde7565d2 100644 --- a/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py +++ b/api/tests/unit_tests/core/app/apps/common/test_workflow_response_converter_truncation.py @@ -10,6 +10,8 @@ from typing import Any from unittest.mock import Mock import pytest +from graphon.entities import WorkflowStartReason +from graphon.enums import BuiltinNodeTypes from core.app.app_config.entities import WorkflowUIBasedAppConfig from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter @@ -24,9 +26,7 @@ from core.app.entities.queue_entities import ( QueueNodeStartedEvent, QueueNodeSucceededEvent, ) -from dify_graph.entities.workflow_start_reason import WorkflowStartReason -from dify_graph.enums import BuiltinNodeTypes -from dify_graph.system_variable import SystemVariable +from core.workflow.system_variables import build_system_variables from libs.datetime_utils import naive_utc_now from models import Account from models.model import AppMode @@ -54,7 +54,7 @@ class TestWorkflowResponseConverter: mock_user.name = "Test User" mock_user.email = "test@example.com" - system_variables = SystemVariable(workflow_id="wf-id", workflow_execution_id="initial-run-id") + system_variables = build_system_variables(workflow_id="wf-id", workflow_execution_id="initial-run-id") return WorkflowResponseConverter( application_generate_entity=mock_entity, user=mock_user, @@ -451,9 +451,9 @@ class 
TestWorkflowResponseConverterServiceApiTruncation: account.id = "test_user_id" return account - def create_test_system_variables(self) -> SystemVariable: + def create_test_system_variables(self): """Create test system variables.""" - return SystemVariable() + return build_system_variables() def create_test_converter(self, invoke_from: InvokeFrom) -> WorkflowResponseConverter: """Create WorkflowResponseConverter with specified invoke_from.""" diff --git a/api/tests/unit_tests/core/app/apps/completion/test_app_runner.py b/api/tests/unit_tests/core/app/apps/completion/test_app_runner.py index 51f33bac35..619d66085a 100644 --- a/api/tests/unit_tests/core/app/apps/completion/test_app_runner.py +++ b/api/tests/unit_tests/core/app/apps/completion/test_app_runner.py @@ -2,11 +2,11 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest +from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent import core.app.apps.completion.app_runner as module from core.app.apps.completion.app_runner import CompletionAppRunner from core.moderation.base import ModerationError -from dify_graph.model_runtime.entities.message_entities import ImagePromptMessageContent @pytest.fixture diff --git a/api/tests/unit_tests/core/app/apps/completion/test_completion_completion_app_generator.py b/api/tests/unit_tests/core/app/apps/completion/test_completion_completion_app_generator.py index 2714757353..96af9fbdee 100644 --- a/api/tests/unit_tests/core/app/apps/completion/test_completion_completion_app_generator.py +++ b/api/tests/unit_tests/core/app/apps/completion/test_completion_completion_app_generator.py @@ -3,13 +3,13 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from pydantic import ValidationError import core.app.apps.completion.app_generator as module from core.app.apps.completion.app_generator import CompletionAppGenerator 
from core.app.apps.exc import GenerateTaskStoppedError from core.app.entities.app_invoke_entities import InvokeFrom -from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError from services.errors.app import MoreLikeThisDisabledError from services.errors.message import MessageNotExistsError diff --git a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_generate_response_converter.py b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_generate_response_converter.py index 94ed8166b9..6cdcab29ab 100644 --- a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_generate_response_converter.py +++ b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_generate_response_converter.py @@ -1,5 +1,7 @@ from collections.abc import Generator +from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus + from core.app.apps.pipeline.generate_response_converter import WorkflowAppGenerateResponseConverter from core.app.entities.task_entities import ( AppStreamResponse, @@ -10,7 +12,6 @@ from core.app.entities.task_entities import ( WorkflowAppBlockingResponse, WorkflowAppStreamResponse, ) -from dify_graph.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus def test_convert_blocking_full_and_simple_response(): diff --git a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_queue_manager.py b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_queue_manager.py index 72f7552bd1..4fe82efcb3 100644 --- a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_queue_manager.py +++ b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_queue_manager.py @@ -1,4 +1,5 @@ import pytest +from graphon.model_runtime.entities.llm_entities import LLMResult import core.app.apps.pipeline.pipeline_queue_manager as module from core.app.apps.base_app_queue_manager import PublishFrom @@ -13,7 +14,6 @@ from core.app.entities.queue_entities import ( QueueWorkflowPartialSuccessEvent, QueueWorkflowSucceededEvent, 
) -from dify_graph.model_runtime.entities.llm_entities import LLMResult def test_publish_sets_stop_listen_and_raises_on_stopped(mocker): diff --git a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_runner.py b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_runner.py index eec95b7f39..ab70996f0a 100644 --- a/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_runner.py +++ b/api/tests/unit_tests/core/app/apps/pipeline/test_pipeline_runner.py @@ -22,11 +22,11 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest +from graphon.graph_events import GraphRunFailedEvent import core.app.apps.pipeline.pipeline_runner as module from core.app.apps.pipeline.pipeline_runner import PipelineRunner from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom -from dify_graph.graph_events import GraphRunFailedEvent def _build_app_generate_entity() -> SimpleNamespace: @@ -284,7 +284,12 @@ def test_run_normal_path_builds_graph(mocker): return_value=SimpleNamespace(belong_to_node_id="start", variable="input1"), ) mocker.patch.object(module, "RAGPipelineVariableInput", side_effect=lambda **kwargs: SimpleNamespace(**kwargs)) - mocker.patch.object(module, "VariablePool", side_effect=lambda **kwargs: SimpleNamespace(**kwargs)) + + class FakeVariablePool: + def add(self, selector, value): + return None + + mocker.patch.object(module, "VariablePool", return_value=FakeVariablePool()) workflow_entry = MagicMock() workflow_entry.graph_engine = MagicMock() diff --git a/api/tests/unit_tests/core/app/apps/test_base_app_generator.py b/api/tests/unit_tests/core/app/apps/test_base_app_generator.py index a3ced02394..6167be3bbd 100644 --- a/api/tests/unit_tests/core/app/apps/test_base_app_generator.py +++ b/api/tests/unit_tests/core/app/apps/test_base_app_generator.py @@ -1,9 +1,7 @@ -from unittest.mock import MagicMock - import pytest +from graphon.variables.input_entities import VariableEntity, VariableEntityType from 
core.app.apps.base_app_generator import BaseAppGenerator -from dify_graph.variables.input_entities import VariableEntity, VariableEntityType def test_validate_inputs_with_zero(): @@ -403,11 +401,11 @@ class TestBaseAppGeneratorExtras: monkeypatch.setattr( "core.app.apps.base_app_generator.file_factory.build_from_mapping", - lambda mapping, tenant_id, config, strict_type_validation=False: "file-object", + lambda mapping, tenant_id, config, strict_type_validation=False, access_controller=None: "file-object", ) monkeypatch.setattr( "core.app.apps.base_app_generator.file_factory.build_from_mappings", - lambda mappings, tenant_id, config: ["file-1", "file-2"], + lambda mappings, tenant_id, config, access_controller=None: ["file-1", "file-2"], ) user_inputs = { @@ -478,8 +476,9 @@ class TestBaseAppGeneratorExtras: assert converted[1] == "event: ping\n\n" def test_get_draft_var_saver_factory_debugger(self): + from graphon.enums import BuiltinNodeTypes + from core.app.entities.app_invoke_entities import InvokeFrom - from dify_graph.enums import BuiltinNodeTypes from models import Account base_app_generator = BaseAppGenerator() @@ -489,7 +488,6 @@ class TestBaseAppGeneratorExtras: factory = base_app_generator._get_draft_var_saver_factory(InvokeFrom.DEBUGGER, account) saver = factory( - session=MagicMock(), app_id="app-id", node_id="node-id", node_type=BuiltinNodeTypes.START, diff --git a/api/tests/unit_tests/core/app/apps/test_base_app_queue_manager.py b/api/tests/unit_tests/core/app/apps/test_base_app_queue_manager.py index c6dc20ffc6..842d14bbd2 100644 --- a/api/tests/unit_tests/core/app/apps/test_base_app_queue_manager.py +++ b/api/tests/unit_tests/core/app/apps/test_base_app_queue_manager.py @@ -59,3 +59,18 @@ class TestBaseAppQueueManager: bad = SimpleNamespace(_sa_instance_state=True) with pytest.raises(TypeError): manager._check_for_sqlalchemy_models(bad) + + def test_stop_listen_defers_graph_runtime_state_cleanup_until_listener_exits(self): + with 
patch("core.app.apps.base_app_queue_manager.redis_client") as mock_redis: + mock_redis.setex.return_value = True + mock_redis.get.return_value = None + manager = DummyQueueManager(task_id="t1", user_id="u1", invoke_from=InvokeFrom.SERVICE_API) + + runtime_state = SimpleNamespace(name="runtime-state") + manager.graph_runtime_state = runtime_state + + manager.stop_listen() + + assert manager.graph_runtime_state is runtime_state + assert list(manager.listen()) == [] + assert manager.graph_runtime_state is None diff --git a/api/tests/unit_tests/core/app/apps/test_base_app_runner.py b/api/tests/unit_tests/core/app/apps/test_base_app_runner.py index aabeb54553..1dee7fdab6 100644 --- a/api/tests/unit_tests/core/app/apps/test_base_app_runner.py +++ b/api/tests/unit_tests/core/app/apps/test_base_app_runner.py @@ -4,6 +4,15 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage +from graphon.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + ImagePromptMessageContent, + PromptMessageRole, + TextPromptMessageContent, +) +from graphon.model_runtime.entities.model_entities import ModelPropertyKey +from graphon.model_runtime.errors.invoke import InvokeBadRequestError from core.app.app_config.entities import ( AdvancedChatMessageEntity, @@ -14,15 +23,6 @@ from core.app.app_config.entities import ( from core.app.apps.base_app_runner import AppRunner from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import QueueAgentMessageEvent, QueueLLMChunkEvent, QueueMessageEndEvent -from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage -from dify_graph.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - ImagePromptMessageContent, - PromptMessageRole, - TextPromptMessageContent, -) -from 
dify_graph.model_runtime.entities.model_entities import ModelPropertyKey -from dify_graph.model_runtime.errors.invoke import InvokeBadRequestError from models.model import AppMode diff --git a/api/tests/unit_tests/core/app/apps/test_pause_resume.py b/api/tests/unit_tests/core/app/apps/test_pause_resume.py index 2f73a8cda8..a126bc85f7 100644 --- a/api/tests/unit_tests/core/app/apps/test_pause_resume.py +++ b/api/tests/unit_tests/core/app/apps/test_pause_resume.py @@ -3,33 +3,34 @@ import time from types import ModuleType, SimpleNamespace from typing import Any -import dify_graph.nodes.human_input.entities # noqa: F401 -from core.app.apps.advanced_chat import app_generator as adv_app_gen_module -from core.app.apps.workflow import app_generator as wf_app_gen_module -from core.app.entities.app_invoke_entities import InvokeFrom -from core.workflow.node_factory import DifyNodeFactory -from dify_graph.entities.base_node_data import BaseNodeData, RetryConfig -from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter -from dify_graph.entities.pause_reason import SchedulingPause -from dify_graph.entities.workflow_start_reason import WorkflowStartReason -from dify_graph.enums import BuiltinNodeTypes, NodeType, WorkflowNodeExecutionStatus -from dify_graph.graph import Graph -from dify_graph.graph_engine import GraphEngine -from dify_graph.graph_engine.command_channels.in_memory_channel import InMemoryChannel -from dify_graph.graph_events import ( +import graphon.nodes.human_input.entities # noqa: F401 +from graphon.entities import WorkflowStartReason +from graphon.entities.base_node_data import BaseNodeData, RetryConfig +from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter +from graphon.entities.pause_reason import SchedulingPause +from graphon.enums import BuiltinNodeTypes, NodeType, WorkflowNodeExecutionStatus +from graphon.graph import Graph +from graphon.graph_engine import GraphEngine +from 
graphon.graph_engine.command_channels import InMemoryChannel +from graphon.graph_events import ( GraphEngineEvent, GraphRunPausedEvent, GraphRunStartedEvent, GraphRunSucceededEvent, NodeRunSucceededEvent, ) -from dify_graph.node_events import NodeRunResult, PauseRequestedEvent -from dify_graph.nodes.base.entities import OutputVariableEntity -from dify_graph.nodes.base.node import Node -from dify_graph.nodes.end.entities import EndNodeData -from dify_graph.nodes.start.entities import StartNodeData -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from graphon.node_events import NodeRunResult, PauseRequestedEvent +from graphon.nodes.base.entities import OutputVariableEntity +from graphon.nodes.base.node import Node +from graphon.nodes.end.entities import EndNodeData +from graphon.nodes.start.entities import StartNodeData +from graphon.runtime import GraphRuntimeState, VariablePool + +from core.app.apps.advanced_chat import app_generator as adv_app_gen_module +from core.app.apps.workflow import app_generator as wf_app_gen_module +from core.app.entities.app_invoke_entities import InvokeFrom +from core.workflow.node_factory import DifyNodeFactory +from core.workflow.system_variables import build_system_variables from tests.workflow_test_utils import build_test_graph_init_params if "core.ops.ops_trace_manager" not in sys.modules: @@ -162,11 +163,11 @@ def _build_graph(runtime_state: GraphRuntimeState, *, pause_on: str | None) -> G def _build_runtime_state(run_id: str) -> GraphRuntimeState: variable_pool = VariablePool( - system_variables=SystemVariable(user_id="user", app_id="app", workflow_id="workflow"), + system_variables=build_system_variables(user_id="user", app_id="app", workflow_id="workflow"), user_inputs={}, conversation_variables=[], ) - variable_pool.system_variables.workflow_execution_id = run_id + variable_pool.add(("sys", "workflow_run_id"), run_id) return 
GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_core.py b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_core.py index 3f1dd14569..de5bca161c 100644 --- a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_core.py +++ b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_core.py @@ -1,9 +1,28 @@ from __future__ import annotations -from datetime import datetime +from datetime import UTC, datetime from types import SimpleNamespace import pytest +from graphon.entities.pause_reason import HumanInputRequired +from graphon.enums import BuiltinNodeTypes +from graphon.graph_events import ( + GraphRunPausedEvent, + GraphRunStartedEvent, + GraphRunSucceededEvent, + NodeRunAgentLogEvent, + NodeRunExceptionEvent, + NodeRunFailedEvent, + NodeRunIterationSucceededEvent, + NodeRunLoopFailedEvent, + NodeRunRetryEvent, + NodeRunStartedEvent, + NodeRunStreamChunkEvent, + NodeRunSucceededEvent, +) +from graphon.node_events import NodeRunResult +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.variables.variables import StringVariable from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom @@ -11,25 +30,16 @@ from core.app.entities.queue_entities import ( QueueAgentLogEvent, QueueIterationCompletedEvent, QueueLoopCompletedEvent, + QueueNodeExceptionEvent, + QueueNodeFailedEvent, + QueueNodeRetryEvent, + QueueNodeSucceededEvent, QueueTextChunkEvent, QueueWorkflowPausedEvent, QueueWorkflowStartedEvent, QueueWorkflowSucceededEvent, ) -from dify_graph.entities.pause_reason import HumanInputRequired -from dify_graph.enums import BuiltinNodeTypes -from dify_graph.graph_events import ( - GraphRunPausedEvent, - GraphRunStartedEvent, - GraphRunSucceededEvent, - NodeRunAgentLogEvent, - NodeRunIterationSucceededEvent, - NodeRunLoopFailedEvent, - 
NodeRunStartedEvent, - NodeRunStreamChunkEvent, -) -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from core.workflow.system_variables import default_system_variables class TestWorkflowBasedAppRunner: @@ -44,7 +54,7 @@ class TestWorkflowBasedAppRunner: runner = WorkflowBasedAppRunner(queue_manager=SimpleNamespace(), app_id="app") runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=SystemVariable.default()), + variable_pool=VariablePool(system_variables=default_system_variables()), start_at=0.0, ) @@ -78,12 +88,12 @@ class TestWorkflowBasedAppRunner: workflow = SimpleNamespace(environment_variables=[], graph_dict={}) with pytest.raises(ValueError, match="Neither single_iteration_run nor single_loop_run"): - runner._prepare_single_node_execution(workflow, None, None) + runner._prepare_single_node_execution(workflow, None, None, user_id="00000000-0000-0000-0000-000000000001") def test_get_graph_and_variable_pool_for_single_node_run(self, monkeypatch): runner = WorkflowBasedAppRunner(queue_manager=SimpleNamespace(), app_id="app") graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=SystemVariable.default()), + variable_pool=VariablePool(system_variables=default_system_variables()), start_at=0.0, ) @@ -126,11 +136,102 @@ class TestWorkflowBasedAppRunner: graph_runtime_state=graph_runtime_state, node_type_filter_key="iteration_id", node_type_label="iteration", + user_id="00000000-0000-0000-0000-000000000001", ) assert graph is not None assert variable_pool is graph_runtime_state.variable_pool + def test_get_graph_and_variable_pool_preloads_constructor_variables_before_graph_init(self, monkeypatch): + variable_loader = SimpleNamespace( + load_variables=lambda selectors: ( + [ + StringVariable( + name="conversation_id", + value="conv-1", + selector=["sys", "conversation_id"], + ) + ] + if selectors + else [] + ) + ) + runner = 
WorkflowBasedAppRunner( + queue_manager=SimpleNamespace(), + variable_loader=variable_loader, + app_id="app", + ) + graph_runtime_state = GraphRuntimeState( + variable_pool=VariablePool(system_variables=default_system_variables()), + start_at=0.0, + ) + + workflow = SimpleNamespace( + tenant_id="tenant", + id="workflow", + graph_dict={ + "nodes": [ + {"id": "loop-node", "data": {"type": "loop", "version": "1", "title": "Loop"}}, + { + "id": "llm-child", + "data": { + "type": "llm", + "version": "1", + "loop_id": "loop-node", + "memory": object(), + }, + }, + ], + "edges": [], + }, + ) + + class _LoopNodeCls: + @staticmethod + def extract_variable_selector_to_variable_mapping(graph_config, config): + return {} + + def _validate_node_config(value): + return {"id": value["id"], "data": SimpleNamespace(**value["data"])} + + def _graph_init(**kwargs): + variable_pool = graph_runtime_state.variable_pool + assert variable_pool.get(["sys", "conversation_id"]) is not None + return SimpleNamespace() + + monkeypatch.setattr( + "core.app.apps.workflow_app_runner.NodeConfigDictAdapter.validate_python", + _validate_node_config, + ) + monkeypatch.setattr( + "core.app.apps.workflow_app_runner.Graph.init", + _graph_init, + ) + monkeypatch.setattr( + "core.app.apps.workflow_app_runner.resolve_workflow_node_class", + lambda **_kwargs: _LoopNodeCls, + ) + monkeypatch.setattr( + "core.app.apps.workflow_app_runner.load_into_variable_pool", + lambda **kwargs: None, + ) + monkeypatch.setattr( + "core.app.apps.workflow_app_runner.WorkflowEntry.mapping_user_inputs_to_variable_pool", + lambda **kwargs: None, + ) + + graph, variable_pool = runner._get_graph_and_variable_pool_for_single_node_run( + workflow=workflow, + node_id="loop-node", + user_inputs={}, + graph_runtime_state=graph_runtime_state, + node_type_filter_key="loop_id", + node_type_label="loop", + ) + + assert graph is not None + assert variable_pool.get(["sys", "conversation_id"]).value == "conv-1" + def 
test_handle_graph_run_events_and_pause_notifications(self, monkeypatch): published: list[object] = [] @@ -140,7 +241,7 @@ class TestWorkflowBasedAppRunner: runner = WorkflowBasedAppRunner(queue_manager=_QueueManager(), app_id="app") graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=SystemVariable.default()), + variable_pool=VariablePool(system_variables=default_system_variables()), start_at=0.0, ) graph_runtime_state.register_paused_node("node-1") @@ -183,7 +284,7 @@ class TestWorkflowBasedAppRunner: runner = WorkflowBasedAppRunner(queue_manager=_QueueManager(), app_id="app") graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=SystemVariable.default()), + variable_pool=VariablePool(system_variables=default_system_variables()), start_at=0.0, ) workflow_entry = SimpleNamespace(graph_engine=SimpleNamespace(graph_runtime_state=graph_runtime_state)) @@ -195,7 +296,7 @@ class TestWorkflowBasedAppRunner: node_id="node", node_type=BuiltinNodeTypes.START, node_title="Start", - start_at=datetime.utcnow(), + start_at=datetime.now(UTC), ), ) runner._handle_event( @@ -232,7 +333,7 @@ class TestWorkflowBasedAppRunner: node_id="node", node_type=BuiltinNodeTypes.LLM, node_title="Iter", - start_at=datetime.utcnow(), + start_at=datetime.now(UTC), inputs={}, outputs={"ok": True}, metadata={}, @@ -246,7 +347,7 @@ class TestWorkflowBasedAppRunner: node_id="node", node_type=BuiltinNodeTypes.LLM, node_title="Loop", - start_at=datetime.utcnow(), + start_at=datetime.now(UTC), inputs={}, outputs={}, metadata={}, @@ -259,3 +360,87 @@ class TestWorkflowBasedAppRunner: assert any(isinstance(event, QueueAgentLogEvent) for event in published) assert any(isinstance(event, QueueIterationCompletedEvent) for event in published) assert any(isinstance(event, QueueLoopCompletedEvent) for event in published) + + @pytest.mark.parametrize( + ("event_factory", "queue_event_cls"), + [ + ( + lambda result, start_at, finished_at: 
NodeRunSucceededEvent( + id="exec", + node_id="node", + node_type=BuiltinNodeTypes.START, + start_at=start_at, + finished_at=finished_at, + node_run_result=result, + ), + QueueNodeSucceededEvent, + ), + ( + lambda result, start_at, finished_at: NodeRunFailedEvent( + id="exec", + node_id="node", + node_type=BuiltinNodeTypes.START, + start_at=start_at, + finished_at=finished_at, + error="boom", + node_run_result=result, + ), + QueueNodeFailedEvent, + ), + ( + lambda result, start_at, finished_at: NodeRunExceptionEvent( + id="exec", + node_id="node", + node_type=BuiltinNodeTypes.START, + start_at=start_at, + finished_at=finished_at, + error="boom", + node_run_result=result, + ), + QueueNodeExceptionEvent, + ), + ( + lambda result, start_at, _finished_at: NodeRunRetryEvent( + id="exec", + node_id="node", + node_type=BuiltinNodeTypes.START, + node_title="Start", + start_at=start_at, + error="boom", + retry_index=1, + node_run_result=result, + ), + QueueNodeRetryEvent, + ), + ], + ) + def test_handle_start_node_result_events_project_outputs(self, event_factory, queue_event_cls): + published: list[object] = [] + + class _QueueManager: + def publish(self, event, publish_from): + published.append(event) + + runner = WorkflowBasedAppRunner(queue_manager=_QueueManager(), app_id="app") + graph_runtime_state = GraphRuntimeState( + variable_pool=VariablePool(system_variables=default_system_variables()), + start_at=0.0, + ) + workflow_entry = SimpleNamespace(graph_engine=SimpleNamespace(graph_runtime_state=graph_runtime_state)) + started_at = datetime.now(UTC) + finished_at = datetime.now(UTC) + result = NodeRunResult( + inputs={"question": "hello"}, + outputs={ + "question": "hello", + "sys.query": "hello", + "env.API_KEY": "secret", + "conversation.session_id": "session-1", + }, + ) + + runner._handle_event(workflow_entry, event_factory(result, started_at, finished_at)) + + queue_event = published[-1] + assert isinstance(queue_event, queue_event_cls) + assert 
queue_event.outputs == {"question": "hello"} diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_notifications.py b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_notifications.py index 1388279221..aa789d9ff3 100644 --- a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_notifications.py +++ b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_notifications.py @@ -1,11 +1,11 @@ from unittest.mock import MagicMock import pytest +from graphon.entities.pause_reason import HumanInputRequired +from graphon.graph_events import GraphRunPausedEvent from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner from core.app.entities.queue_entities import QueueWorkflowPausedEvent -from dify_graph.entities.pause_reason import HumanInputRequired -from dify_graph.graph_events.graph import GraphRunPausedEvent class _DummyQueueManager: diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py index 178e26118e..9e30faecf2 100644 --- a/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py +++ b/api/tests/unit_tests/core/app/apps/test_workflow_app_runner_single_node.py @@ -4,20 +4,20 @@ from typing import Any from unittest.mock import MagicMock, patch import pytest +from graphon.entities.graph_config import NodeConfigDictAdapter +from graphon.runtime import GraphRuntimeState, VariablePool from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.workflow.app_runner import WorkflowAppRunner from core.app.apps.workflow_app_runner import WorkflowBasedAppRunner from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity -from dify_graph.entities.graph_config import NodeConfigDictAdapter -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from core.workflow.system_variables import 
default_system_variables from models.workflow import Workflow def _make_graph_state(): variable_pool = VariablePool( - system_variables=SystemVariable.default(), + system_variables=default_system_variables(), user_inputs={}, environment_variables=[], conversation_variables=[], @@ -100,6 +100,7 @@ def test_run_uses_single_node_execution_branch( workflow=workflow, single_iteration_run=single_iteration_run, single_loop_run=single_loop_run, + user_id="user", ) init_graph.assert_not_called() @@ -158,6 +159,7 @@ def test_single_node_run_validates_target_node_config(monkeypatch) -> None: graph_runtime_state=graph_runtime_state, node_type_filter_key="loop_id", node_type_label="loop", + user_id="00000000-0000-0000-0000-000000000001", ) assert seen_configs == [workflow.graph_dict["nodes"][0]] diff --git a/api/tests/unit_tests/core/app/apps/test_workflow_pause_events.py b/api/tests/unit_tests/core/app/apps/test_workflow_pause_events.py index 65c6bd6654..8a717e1dcc 100644 --- a/api/tests/unit_tests/core/app/apps/test_workflow_pause_events.py +++ b/api/tests/unit_tests/core/app/apps/test_workflow_pause_events.py @@ -3,6 +3,11 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest +from graphon.entities import WorkflowStartReason +from graphon.entities.pause_reason import HumanInputRequired +from graphon.graph_events import GraphRunPausedEvent +from graphon.nodes.human_input.entities import FormInput, UserAction +from graphon.nodes.human_input.enums import FormInputType from core.app.apps.common import workflow_response_converter from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter @@ -10,13 +15,9 @@ from core.app.apps.workflow.app_runner import WorkflowAppRunner from core.app.entities.app_invoke_entities import InvokeFrom from core.app.entities.queue_entities import QueueWorkflowPausedEvent from core.app.entities.task_entities import HumanInputRequiredResponse, WorkflowPauseStreamResponse -from 
dify_graph.entities.pause_reason import HumanInputRequired -from dify_graph.entities.workflow_start_reason import WorkflowStartReason -from dify_graph.graph_events.graph import GraphRunPausedEvent -from dify_graph.nodes.human_input.entities import FormInput, UserAction -from dify_graph.nodes.human_input.enums import FormInputType -from dify_graph.system_variable import SystemVariable +from core.workflow.system_variables import build_system_variables from models.account import Account +from models.human_input import RecipientType class _RecordingWorkflowAppRunner(WorkflowAppRunner): @@ -74,7 +75,6 @@ def test_graph_run_paused_event_emits_queue_pause_event(): actions=[], node_id="node-human", node_title="Human Step", - form_token="tok", ) event = GraphRunPausedEvent(reasons=[reason], outputs={"foo": "bar"}) workflow_entry = SimpleNamespace( @@ -98,7 +98,7 @@ def _build_converter(): invoke_from=InvokeFrom.SERVICE_API, app_config=SimpleNamespace(app_id="app-id", tenant_id="tenant-id"), ) - system_variables = SystemVariable( + system_variables = build_system_variables( user_id="user", app_id="app-id", workflow_id="workflow-id", @@ -128,7 +128,21 @@ def test_queue_workflow_paused_event_to_stream_responses(monkeypatch: pytest.Mon class _FakeSession: def execute(self, _stmt): - return [("form-1", expiration_time)] + return [("form-1", expiration_time, '{"display_in_ui": true}')] + + def scalars(self, _stmt): + return [ + SimpleNamespace( + form_id="form-1", + recipient_type=RecipientType.CONSOLE, + access_token="console-token", + ), + SimpleNamespace( + form_id="form-1", + recipient_type=RecipientType.BACKSTAGE, + access_token="backstage-token", + ), + ] def __enter__(self): return self @@ -146,10 +160,8 @@ def test_queue_workflow_paused_event_to_stream_responses(monkeypatch: pytest.Mon FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="field", default=None), ], actions=[UserAction(id="approve", title="Approve")], - display_in_ui=True, node_id="node-id", 
node_title="Human Step", - form_token="token", ) queue_event = QueueWorkflowPausedEvent( reasons=[reason], @@ -170,7 +182,6 @@ def test_queue_workflow_paused_event_to_stream_responses(monkeypatch: pytest.Mon assert pause_resp.data.paused_nodes == ["node-id"] assert pause_resp.data.outputs == {} assert pause_resp.data.reasons[0]["form_id"] == "form-1" - assert pause_resp.data.reasons[0]["display_in_ui"] is True assert isinstance(responses[0], HumanInputRequiredResponse) hi_resp = responses[0] @@ -180,4 +191,5 @@ def test_queue_workflow_paused_event_to_stream_responses(monkeypatch: pytest.Mon assert hi_resp.data.inputs[0].output_variable_name == "field" assert hi_resp.data.actions[0].id == "approve" assert hi_resp.data.display_in_ui is True + assert hi_resp.data.form_token == "backstage-token" assert hi_resp.data.expiration_time == int(expiration_time.timestamp()) diff --git a/api/tests/unit_tests/core/app/apps/workflow/test_generate_response_converter.py b/api/tests/unit_tests/core/app/apps/workflow/test_generate_response_converter.py index 62e94a7580..b768e813bd 100644 --- a/api/tests/unit_tests/core/app/apps/workflow/test_generate_response_converter.py +++ b/api/tests/unit_tests/core/app/apps/workflow/test_generate_response_converter.py @@ -1,5 +1,7 @@ from collections.abc import Generator +from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus + from core.app.apps.workflow.generate_response_converter import WorkflowAppGenerateResponseConverter from core.app.entities.task_entities import ( ErrorStreamResponse, @@ -9,7 +11,6 @@ from core.app.entities.task_entities import ( WorkflowAppBlockingResponse, WorkflowAppStreamResponse, ) -from dify_graph.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus class TestWorkflowGenerateResponseConverter: diff --git a/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline.py b/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline.py index 
5b23e71035..29df903aa8 100644 --- a/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline.py +++ b/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline.py @@ -2,16 +2,18 @@ import time from contextlib import contextmanager from unittest.mock import MagicMock +from graphon.entities import WorkflowStartReason +from graphon.runtime import GraphRuntimeState + from core.app.app_config.entities import WorkflowUIBasedAppConfig from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity from core.app.entities.queue_entities import QueueWorkflowStartedEvent -from dify_graph.entities.workflow_start_reason import WorkflowStartReason -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from core.workflow.system_variables import build_system_variables from models.account import Account from models.model import AppMode +from tests.workflow_test_utils import build_test_variable_pool def _build_workflow_app_config() -> WorkflowUIBasedAppConfig: @@ -37,11 +39,7 @@ def _build_generate_entity(run_id: str) -> WorkflowAppGenerateEntity: def _build_runtime_state(run_id: str) -> GraphRuntimeState: - variable_pool = VariablePool( - system_variables=SystemVariable(workflow_execution_id=run_id), - user_inputs={}, - conversation_variables=[], - ) + variable_pool = build_test_variable_pool(variables=build_system_variables(workflow_execution_id=run_id)) return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) diff --git a/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline_core.py b/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline_core.py index f35710d207..dabd2594b4 100644 --- 
a/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline_core.py +++ b/api/tests/unit_tests/core/app/apps/workflow/test_generate_task_pipeline_core.py @@ -1,10 +1,11 @@ from __future__ import annotations from contextlib import contextmanager -from datetime import datetime from types import SimpleNamespace import pytest +from graphon.enums import BuiltinNodeTypes, WorkflowExecutionStatus +from graphon.runtime import GraphRuntimeState, VariablePool from core.app.app_config.entities import AppAdditionalFeatures, WorkflowUIBasedAppConfig from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline @@ -44,11 +45,11 @@ from core.app.entities.task_entities import ( WorkflowStartStreamResponse, ) from core.base.tts.app_generator_tts_publisher import AudioTrunk -from dify_graph.enums import BuiltinNodeTypes, WorkflowExecutionStatus -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from core.workflow.system_variables import build_system_variables, system_variables_to_mapping +from libs.datetime_utils import naive_utc_now from models.enums import CreatorUserRole from models.model import AppMode, EndUser +from tests.workflow_test_utils import build_test_variable_pool def _make_pipeline(): @@ -164,7 +165,7 @@ class TestWorkflowGenerateTaskPipeline: def test_handle_workflow_started_event_sets_run_id(self, monkeypatch): pipeline = _make_pipeline() pipeline._graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")), + variable_pool=build_test_variable_pool(variables=build_system_variables(workflow_execution_id="run-id")), start_at=0.0, ) pipeline._workflow_response_converter.workflow_start_to_stream_response = lambda **kwargs: "started" @@ -191,7 +192,7 @@ class TestWorkflowGenerateTaskPipeline: node_execution_id="exec", node_id="node", node_type=BuiltinNodeTypes.START, - 
start_at=datetime.utcnow(), + start_at=naive_utc_now(), inputs={}, outputs={}, process_data={}, @@ -205,7 +206,7 @@ class TestWorkflowGenerateTaskPipeline: pipeline = _make_pipeline() pipeline._workflow_execution_id = "run-id" pipeline._graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")), + variable_pool=build_test_variable_pool(variables=build_system_variables(workflow_execution_id="run-id")), start_at=0.0, ) pipeline._workflow_response_converter.workflow_finish_to_stream_response = lambda **kwargs: "finish" @@ -244,7 +245,7 @@ class TestWorkflowGenerateTaskPipeline: node_execution_id="exec", node_id="node", node_type=BuiltinNodeTypes.START, - start_at=datetime.utcnow(), + start_at=naive_utc_now(), inputs={}, outputs={}, process_data={}, @@ -257,7 +258,7 @@ class TestWorkflowGenerateTaskPipeline: pipeline = _make_pipeline() pipeline._workflow_execution_id = "run-id" pipeline._graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")), + variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")), start_at=0.0, ) pipeline._workflow_response_converter.workflow_finish_to_stream_response = lambda **kwargs: "finish" @@ -302,7 +303,7 @@ class TestWorkflowGenerateTaskPipeline: node_id="node", node_type=BuiltinNodeTypes.LLM, node_title="LLM", - start_at=datetime.utcnow(), + start_at=naive_utc_now(), node_run_index=1, ) iter_next = QueueIterationNextEvent( @@ -318,7 +319,7 @@ class TestWorkflowGenerateTaskPipeline: node_id="node", node_type=BuiltinNodeTypes.LLM, node_title="LLM", - start_at=datetime.utcnow(), + start_at=naive_utc_now(), node_run_index=1, ) loop_start = QueueLoopStartEvent( @@ -326,7 +327,7 @@ class TestWorkflowGenerateTaskPipeline: node_id="node", node_type=BuiltinNodeTypes.LLM, node_title="LLM", - start_at=datetime.utcnow(), + start_at=naive_utc_now(), 
node_run_index=1, ) loop_next = QueueLoopNextEvent( @@ -342,7 +343,7 @@ class TestWorkflowGenerateTaskPipeline: node_id="node", node_type=BuiltinNodeTypes.LLM, node_title="LLM", - start_at=datetime.utcnow(), + start_at=naive_utc_now(), node_run_index=1, ) filled_event = QueueHumanInputFormFilledEvent( @@ -358,7 +359,7 @@ class TestWorkflowGenerateTaskPipeline: node_id="node", node_type=BuiltinNodeTypes.LLM, node_title="title", - expiration_time=datetime.utcnow(), + expiration_time=naive_utc_now(), ) agent_event = QueueAgentLogEvent( id="log", @@ -451,7 +452,7 @@ class TestWorkflowGenerateTaskPipeline: ) assert pipeline._created_by_role == CreatorUserRole.END_USER - assert pipeline._workflow_system_variables.user_id == "session-id" + assert system_variables_to_mapping(pipeline._workflow_system_variables)["user_id"] == "session-id" def test_process_returns_stream_and_blocking_variants(self): pipeline = _make_pipeline() @@ -647,7 +648,7 @@ class TestWorkflowGenerateTaskPipeline: node_title="title", node_type=BuiltinNodeTypes.LLM, node_run_index=1, - start_at=datetime.utcnow(), + start_at=naive_utc_now(), provider_type="provider", provider_id="provider-id", error="error", @@ -659,7 +660,7 @@ class TestWorkflowGenerateTaskPipeline: node_title="title", node_type=BuiltinNodeTypes.LLM, node_run_index=1, - start_at=datetime.utcnow(), + start_at=naive_utc_now(), provider_type="provider", provider_id="provider-id", ) @@ -684,7 +685,7 @@ class TestWorkflowGenerateTaskPipeline: node_execution_id="exec-id", node_id="node", node_type=BuiltinNodeTypes.START, - start_at=datetime.utcnow(), + start_at=naive_utc_now(), inputs={}, outputs={}, process_data={}, @@ -699,7 +700,7 @@ class TestWorkflowGenerateTaskPipeline: pipeline = _make_pipeline() pipeline._workflow_execution_id = "run-id" pipeline._graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")), + 
variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")), start_at=0.0, ) @@ -727,7 +728,7 @@ class TestWorkflowGenerateTaskPipeline: pipeline = _make_pipeline() pipeline._workflow_execution_id = "run-id" pipeline._graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")), + variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")), start_at=0.0, ) pipeline._handle_ping_event = lambda event, **kwargs: iter(["ping"]) @@ -743,7 +744,7 @@ class TestWorkflowGenerateTaskPipeline: def test_process_stream_response_main_match_paths_and_cleanup(self): pipeline = _make_pipeline() pipeline._graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=SystemVariable(workflow_execution_id="run-id")), + variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-id")), start_at=0.0, ) pipeline._base_task_pipeline.queue_manager.listen = lambda: iter( @@ -815,7 +816,7 @@ class TestWorkflowGenerateTaskPipeline: pipeline._save_workflow_app_log(session=_Session(), workflow_run_id=None) assert len(added) == count_before - def test_save_output_for_event_writes_draft_variables(self, monkeypatch): + def test_save_output_for_event_writes_draft_variables(self): pipeline = _make_pipeline() saver_calls: list[tuple[object, object]] = [] captured_factory_args: dict[str, object] = {} @@ -828,36 +829,14 @@ class TestWorkflowGenerateTaskPipeline: captured_factory_args.update(kwargs) return _Saver() - class _Begin: - def __enter__(self): - return None - - def __exit__(self, exc_type, exc, tb): - return False - - class _Session: - def __init__(self, *args, **kwargs): - _ = args, kwargs - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc, tb): - return False - - def begin(self): - return _Begin() - pipeline._draft_var_saver_factory = _factory - 
monkeypatch.setattr("core.app.apps.workflow.generate_task_pipeline.Session", _Session) - monkeypatch.setattr("core.app.apps.workflow.generate_task_pipeline.db", SimpleNamespace(engine=object())) event = QueueNodeSucceededEvent( node_execution_id="exec-id", node_id="node-id", node_type=BuiltinNodeTypes.START, in_loop_id="loop-id", - start_at=datetime.utcnow(), + start_at=naive_utc_now(), process_data={"k": "v"}, outputs={"out": 1}, ) diff --git a/api/tests/unit_tests/core/app/entities/test_task_entities.py b/api/tests/unit_tests/core/app/entities/test_task_entities.py index 8ecab3199c..014a0cba72 100644 --- a/api/tests/unit_tests/core/app/entities/test_task_entities.py +++ b/api/tests/unit_tests/core/app/entities/test_task_entities.py @@ -1,10 +1,11 @@ +from graphon.enums import WorkflowNodeExecutionStatus + from core.app.entities.task_entities import ( NodeFinishStreamResponse, NodeRetryStreamResponse, NodeStartStreamResponse, StreamEvent, ) -from dify_graph.enums import WorkflowNodeExecutionStatus class TestTaskEntities: diff --git a/api/tests/unit_tests/core/app/layers/test_conversation_variable_persist_layer.py b/api/tests/unit_tests/core/app/layers/test_conversation_variable_persist_layer.py index bdc889d941..a78c1b428f 100644 --- a/api/tests/unit_tests/core/app/layers/test_conversation_variable_persist_layer.py +++ b/api/tests/unit_tests/core/app/layers/test_conversation_variable_persist_layer.py @@ -1,18 +1,18 @@ from collections.abc import Sequence -from datetime import datetime from unittest.mock import Mock +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus +from graphon.graph_engine.command_channels import CommandChannel +from graphon.graph_events import NodeRunSucceededEvent, NodeRunVariableUpdatedEvent +from graphon.node_events import NodeRunResult +from graphon.runtime import ReadOnlyGraphRuntimeState +from graphon.variables import StringVariable +from graphon.variables.segments import Segment, StringSegment + from 
core.app.layers.conversation_variable_persist_layer import ConversationVariablePersistenceLayer -from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID -from dify_graph.enums import BuiltinNodeTypes, NodeType, WorkflowNodeExecutionStatus -from dify_graph.graph_engine.protocols.command_channel import CommandChannel -from dify_graph.graph_events.node import NodeRunSucceededEvent -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.variable_assigner.common import helpers as common_helpers -from dify_graph.runtime.graph_runtime_state_protocol import ReadOnlyGraphRuntimeState -from dify_graph.system_variable import SystemVariable -from dify_graph.variables import StringVariable -from dify_graph.variables.segments import Segment +from core.workflow.system_variables import SystemVariableKey +from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID +from libs.datetime_utils import naive_utc_now class MockReadOnlyVariablePool: @@ -36,31 +36,38 @@ def _build_graph_runtime_state( conversation_id: str | None = None, ) -> ReadOnlyGraphRuntimeState: graph_runtime_state = Mock(spec=ReadOnlyGraphRuntimeState) + if conversation_id is not None: + variable_pool._variables[("sys", SystemVariableKey.CONVERSATION_ID.value)] = StringSegment( + value=conversation_id + ) graph_runtime_state.variable_pool = variable_pool - graph_runtime_state.system_variable = SystemVariable(conversation_id=conversation_id).as_view() return graph_runtime_state -def _build_node_run_succeeded_event( - *, - node_type: NodeType, - outputs: dict[str, object] | None = None, - process_data: dict[str, object] | None = None, -) -> NodeRunSucceededEvent: +def _build_node_run_succeeded_event() -> NodeRunSucceededEvent: return NodeRunSucceededEvent( id="node-exec-id", node_id="assigner", - node_type=node_type, - start_at=datetime.utcnow(), + node_type=BuiltinNodeTypes.LLM, + start_at=naive_utc_now(), node_run_result=NodeRunResult( 
status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs=outputs or {}, - process_data=process_data or {}, + outputs={}, + process_data={}, ), ) -def test_persists_conversation_variables_from_assigner_output(): +def _build_variable_updated_event(variable: StringVariable) -> NodeRunVariableUpdatedEvent: + return NodeRunVariableUpdatedEvent( + id="node-exec-id", + node_id="assigner", + node_type=BuiltinNodeTypes.VARIABLE_ASSIGNER, + variable=variable, + ) + + +def test_persists_conversation_variables_from_variable_update_event(): conversation_id = "conv-123" variable = StringVariable( id="var-1", @@ -68,55 +75,26 @@ def test_persists_conversation_variables_from_assigner_output(): value="updated", selector=[CONVERSATION_VARIABLE_NODE_ID, "name"], ) - process_data = common_helpers.set_updated_variables( - {}, [common_helpers.variable_to_processed_data(variable.selector, variable)] - ) - - variable_pool = MockReadOnlyVariablePool({(CONVERSATION_VARIABLE_NODE_ID, "name"): variable}) - updater = Mock() layer = ConversationVariablePersistenceLayer(updater) - layer.initialize(_build_graph_runtime_state(variable_pool, conversation_id), Mock(spec=CommandChannel)) + layer.initialize(_build_graph_runtime_state(MockReadOnlyVariablePool(), conversation_id), Mock(spec=CommandChannel)) - event = _build_node_run_succeeded_event(node_type=BuiltinNodeTypes.VARIABLE_ASSIGNER, process_data=process_data) + event = _build_variable_updated_event(variable) layer.on_event(event) updater.update.assert_called_once_with(conversation_id=conversation_id, variable=variable) - updater.flush.assert_called_once() -def test_skips_when_outputs_missing(): +def test_skips_non_variable_update_events(): conversation_id = "conv-456" - variable = StringVariable( - id="var-2", - name="name", - value="updated", - selector=[CONVERSATION_VARIABLE_NODE_ID, "name"], - ) - - variable_pool = MockReadOnlyVariablePool({(CONVERSATION_VARIABLE_NODE_ID, "name"): variable}) - updater = Mock() layer = 
ConversationVariablePersistenceLayer(updater) - layer.initialize(_build_graph_runtime_state(variable_pool, conversation_id), Mock(spec=CommandChannel)) + layer.initialize(_build_graph_runtime_state(MockReadOnlyVariablePool(), conversation_id), Mock(spec=CommandChannel)) - event = _build_node_run_succeeded_event(node_type=BuiltinNodeTypes.VARIABLE_ASSIGNER) + event = _build_node_run_succeeded_event() layer.on_event(event) updater.update.assert_not_called() - updater.flush.assert_not_called() - - -def test_skips_non_assigner_nodes(): - updater = Mock() - layer = ConversationVariablePersistenceLayer(updater) - layer.initialize(_build_graph_runtime_state(MockReadOnlyVariablePool()), Mock(spec=CommandChannel)) - - event = _build_node_run_succeeded_event(node_type=BuiltinNodeTypes.LLM) - layer.on_event(event) - - updater.update.assert_not_called() - updater.flush.assert_not_called() def test_skips_non_conversation_variables(): @@ -127,18 +105,11 @@ def test_skips_non_conversation_variables(): value="updated", selector=["environment", "name"], ) - process_data = common_helpers.set_updated_variables( - {}, [common_helpers.variable_to_processed_data(non_conversation_variable.selector, non_conversation_variable)] - ) - - variable_pool = MockReadOnlyVariablePool() - updater = Mock() layer = ConversationVariablePersistenceLayer(updater) - layer.initialize(_build_graph_runtime_state(variable_pool, conversation_id), Mock(spec=CommandChannel)) + layer.initialize(_build_graph_runtime_state(MockReadOnlyVariablePool(), conversation_id), Mock(spec=CommandChannel)) - event = _build_node_run_succeeded_event(node_type=BuiltinNodeTypes.VARIABLE_ASSIGNER, process_data=process_data) + event = _build_variable_updated_event(non_conversation_variable) layer.on_event(event) updater.update.assert_not_called() - updater.flush.assert_not_called() diff --git a/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py 
b/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py index 035f0ee05c..035e64325b 100644 --- a/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py +++ b/api/tests/unit_tests/core/app/layers/test_pause_state_persist_layer.py @@ -4,6 +4,17 @@ from time import time from unittest.mock import Mock import pytest +from graphon.entities.pause_reason import SchedulingPause +from graphon.graph_engine.entities.commands import GraphEngineCommand +from graphon.graph_engine.layers.base import GraphEngineLayerNotInitializedError +from graphon.graph_events import ( + GraphRunFailedEvent, + GraphRunPausedEvent, + GraphRunStartedEvent, + GraphRunSucceededEvent, +) +from graphon.runtime import ReadOnlyVariablePool +from graphon.variables.segments import Segment from core.app.app_config.entities import WorkflowUIBasedAppConfig from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity @@ -13,17 +24,7 @@ from core.app.layers.pause_state_persist_layer import ( _AdvancedChatAppGenerateEntityWrapper, _WorkflowGenerateEntityWrapper, ) -from dify_graph.entities.pause_reason import SchedulingPause -from dify_graph.graph_engine.entities.commands import GraphEngineCommand -from dify_graph.graph_engine.layers.base import GraphEngineLayerNotInitializedError -from dify_graph.graph_events.graph import ( - GraphRunFailedEvent, - GraphRunPausedEvent, - GraphRunStartedEvent, - GraphRunSucceededEvent, -) -from dify_graph.runtime.graph_runtime_state_protocol import ReadOnlyVariablePool -from dify_graph.variables.segments import Segment +from core.workflow.system_variables import SystemVariableKey from models.model import AppMode from repositories.factory import DifyAPIRepositoryFactory @@ -51,17 +52,6 @@ class TestDataFactory: return GraphRunFailedEvent(error=error, exceptions_count=exceptions_count) -class MockSystemVariableReadOnlyView: - """Minimal read-only system variable view for testing.""" - - def 
__init__(self, workflow_execution_id: str | None = None) -> None: - self._workflow_execution_id = workflow_execution_id - - @property - def workflow_execution_id(self) -> str | None: - return self._workflow_execution_id - - class MockReadOnlyVariablePool: """Mock implementation of ReadOnlyVariablePool for testing.""" @@ -76,13 +66,14 @@ class MockReadOnlyVariablePool: return None mock_segment = Mock(spec=Segment) mock_segment.value = value + mock_segment.text = value if isinstance(value, str) else None return mock_segment def get_all_by_node(self, node_id: str) -> dict[str, object]: return {key: value for (nid, key), value in self._variables.items() if nid == node_id} def get_by_prefix(self, prefix: str) -> dict[str, object]: - return {f"{nid}.{key}": value for (nid, key), value in self._variables.items() if nid.startswith(prefix)} + return {key: value for (nid, key), value in self._variables.items() if nid == prefix} class MockReadOnlyGraphRuntimeState: @@ -105,12 +96,10 @@ class MockReadOnlyGraphRuntimeState: self._ready_queue_size = ready_queue_size self._exceptions_count = exceptions_count self._outputs = outputs or {} - self._variable_pool = MockReadOnlyVariablePool(variables) - self._system_variable = MockSystemVariableReadOnlyView(workflow_execution_id) - - @property - def system_variable(self) -> MockSystemVariableReadOnlyView: - return self._system_variable + resolved_variables = dict(variables or {}) + if workflow_execution_id is not None: + resolved_variables[("sys", SystemVariableKey.WORKFLOW_EXECUTION_ID.value)] = workflow_execution_id + self._variable_pool = MockReadOnlyVariablePool(resolved_variables) @property def variable_pool(self) -> ReadOnlyVariablePool: @@ -161,7 +150,9 @@ class MockReadOnlyGraphRuntimeState: "exceptions_count": self._exceptions_count, "outputs": self._outputs, "variables": {f"{k[0]}.{k[1]}": v for k, v in self._variable_pool._variables.items()}, - "workflow_execution_id": self._system_variable.workflow_execution_id, + 
"workflow_execution_id": self._variable_pool._variables.get( + ("sys", SystemVariableKey.WORKFLOW_EXECUTION_ID.value) + ), } ) diff --git a/api/tests/unit_tests/core/app/layers/test_suspend_layer.py b/api/tests/unit_tests/core/app/layers/test_suspend_layer.py index c6d820dbc9..95931f4f8b 100644 --- a/api/tests/unit_tests/core/app/layers/test_suspend_layer.py +++ b/api/tests/unit_tests/core/app/layers/test_suspend_layer.py @@ -1,5 +1,6 @@ +from graphon.graph_events import GraphRunPausedEvent + from core.app.layers.suspend_layer import SuspendLayer -from dify_graph.graph_events.graph import GraphRunPausedEvent class TestSuspendLayer: diff --git a/api/tests/unit_tests/core/app/layers/test_timeslice_layer.py b/api/tests/unit_tests/core/app/layers/test_timeslice_layer.py index c87eec1508..7cf6eb4f31 100644 --- a/api/tests/unit_tests/core/app/layers/test_timeslice_layer.py +++ b/api/tests/unit_tests/core/app/layers/test_timeslice_layer.py @@ -1,7 +1,8 @@ from unittest.mock import Mock, patch +from graphon.graph_engine.entities.commands import CommandType, GraphEngineCommand + from core.app.layers.timeslice_layer import TimeSliceLayer -from dify_graph.graph_engine.entities.commands import CommandType, GraphEngineCommand from services.workflow.entities import WorkflowScheduleCFSPlanEntity from services.workflow.scheduler import SchedulerCommand diff --git a/api/tests/unit_tests/core/app/layers/test_trigger_post_layer.py b/api/tests/unit_tests/core/app/layers/test_trigger_post_layer.py index f9755061d6..aa9285789b 100644 --- a/api/tests/unit_tests/core/app/layers/test_trigger_post_layer.py +++ b/api/tests/unit_tests/core/app/layers/test_trigger_post_layer.py @@ -2,8 +2,11 @@ from datetime import UTC, datetime, timedelta from types import SimpleNamespace from unittest.mock import Mock, patch +from graphon.graph_events import GraphRunFailedEvent, GraphRunSucceededEvent +from graphon.runtime import VariablePool + from core.app.layers.trigger_post_layer import TriggerPostLayer 
-from dify_graph.graph_events.graph import GraphRunFailedEvent, GraphRunSucceededEvent +from core.workflow.system_variables import build_system_variables from models.enums import WorkflowTriggerStatus @@ -19,7 +22,7 @@ class TestTriggerPostLayer: ) runtime_state = SimpleNamespace( outputs={"answer": "ok"}, - system_variable=SimpleNamespace(workflow_execution_id="run-1"), + variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-1")), total_tokens=12, ) @@ -58,7 +61,7 @@ class TestTriggerPostLayer: def test_on_event_handles_missing_trigger_log(self): runtime_state = SimpleNamespace( outputs={}, - system_variable=SimpleNamespace(workflow_execution_id="run-1"), + variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-1")), total_tokens=0, ) @@ -89,7 +92,7 @@ class TestTriggerPostLayer: def test_on_event_ignores_non_status_events(self): runtime_state = SimpleNamespace( outputs={}, - system_variable=SimpleNamespace(workflow_execution_id="run-1"), + variable_pool=VariablePool(system_variables=build_system_variables(workflow_execution_id="run-1")), total_tokens=0, ) diff --git a/api/tests/unit_tests/core/app/task_pipeline/test_based_generate_task_pipeline.py b/api/tests/unit_tests/core/app/task_pipeline/test_based_generate_task_pipeline.py index e070eb06fd..58aa7d7478 100644 --- a/api/tests/unit_tests/core/app/task_pipeline/test_based_generate_task_pipeline.py +++ b/api/tests/unit_tests/core/app/task_pipeline/test_based_generate_task_pipeline.py @@ -2,11 +2,11 @@ from types import SimpleNamespace from unittest.mock import Mock import pytest +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.app.entities.queue_entities import QueueErrorEvent from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline from core.errors.error import QuotaExceededError -from dify_graph.model_runtime.errors.invoke import 
InvokeAuthorizationError, InvokeError from models.enums import MessageStatus diff --git a/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline.py b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline.py index 13fbca6e26..4aaa10a81a 100644 --- a/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline.py +++ b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline.py @@ -2,6 +2,8 @@ from types import SimpleNamespace from unittest.mock import ANY, Mock, patch import pytest +from graphon.model_runtime.entities.llm_entities import LLMResult as RuntimeLLMResult +from graphon.model_runtime.entities.message_entities import TextPromptMessageContent from core.app.apps.base_app_queue_manager import AppQueueManager from core.app.entities.app_invoke_entities import ChatAppGenerateEntity @@ -26,8 +28,6 @@ from core.app.entities.task_entities import ( from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline from core.base.tts import AppGeneratorTTSPublisher from core.ops.ops_trace_manager import TraceQueueManager -from dify_graph.model_runtime.entities.llm_entities import LLMResult as RuntimeLLMResult -from dify_graph.model_runtime.entities.message_entities import TextPromptMessageContent from models.model import AppMode diff --git a/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline_core.py b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline_core.py index 155e6f2c73..f7e7b7e20e 100644 --- a/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline_core.py +++ b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_based_generate_task_pipeline_core.py @@ -5,6 +5,9 @@ from types import SimpleNamespace from unittest.mock import Mock import pytest +from graphon.file import FileTransferMethod +from 
graphon.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage +from graphon.model_runtime.entities.message_entities import AssistantPromptMessage, TextPromptMessageContent from core.app.app_config.entities import ( AppAdditionalFeatures, @@ -38,9 +41,6 @@ from core.app.entities.task_entities import ( ) from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline from core.base.tts import AudioTrunk -from dify_graph.file.enums import FileTransferMethod -from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage -from dify_graph.model_runtime.entities.message_entities import AssistantPromptMessage, TextPromptMessageContent from models.model import AppMode diff --git a/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_message_end_files.py b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_message_end_files.py index 37dd116470..31b7313066 100644 --- a/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_message_end_files.py +++ b/api/tests/unit_tests/core/app/task_pipeline/test_easy_ui_message_end_files.py @@ -17,11 +17,11 @@ import uuid from unittest.mock import MagicMock, Mock, patch import pytest +from graphon.file import FileTransferMethod, FileType from sqlalchemy.orm import Session from core.app.entities.task_entities import MessageEndStreamResponse from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline -from dify_graph.file.enums import FileTransferMethod, FileType from models.model import MessageFile, UploadFile diff --git a/api/tests/unit_tests/core/app/test_easy_ui_model_config_manager.py b/api/tests/unit_tests/core/app/test_easy_ui_model_config_manager.py new file mode 100644 index 0000000000..29df7eea86 --- /dev/null +++ b/api/tests/unit_tests/core/app/test_easy_ui_model_config_manager.py @@ -0,0 +1,58 @@ +from types import SimpleNamespace +from 
unittest.mock import patch + +from graphon.model_runtime.entities.model_entities import ModelPropertyKey + +from core.app.app_config.easy_ui_based_app.model_config.manager import ModelConfigManager +from core.app.app_config.entities import ModelConfigEntity +from models.provider_ids import ModelProviderID + + +def test_validate_and_set_defaults_reuses_single_model_assembly(): + provider_name = str(ModelProviderID("openai")) + provider_entity = SimpleNamespace(provider=provider_name) + model = SimpleNamespace(model="gpt-4o-mini", model_properties={ModelPropertyKey.MODE: "chat"}) + provider_configurations = SimpleNamespace(get_models=lambda **kwargs: [model]) + assembly = SimpleNamespace( + model_provider_factory=SimpleNamespace(get_providers=lambda: [provider_entity]), + provider_manager=SimpleNamespace(get_configurations=lambda tenant_id: provider_configurations), + ) + config = { + "model": { + "provider": "openai", + "name": "gpt-4o-mini", + "completion_params": {"stop": []}, + } + } + + with patch( + "core.app.app_config.easy_ui_based_app.model_config.manager.create_plugin_model_assembly", + return_value=assembly, + ) as mock_assembly: + result, keys = ModelConfigManager.validate_and_set_defaults("tenant-1", config) + + assert result["model"]["provider"] == provider_name + assert result["model"]["mode"] == "chat" + assert keys == ["model"] + mock_assembly.assert_called_once_with(tenant_id="tenant-1") + + +def test_convert_keeps_model_config_shape(): + config = { + "model": { + "provider": "openai", + "name": "gpt-4o-mini", + "mode": "chat", + "completion_params": {"temperature": 0.3, "stop": ["END"]}, + } + } + + result = ModelConfigManager.convert(config) + + assert result == ModelConfigEntity( + provider="openai", + model="gpt-4o-mini", + mode="chat", + parameters={"temperature": 0.3}, + stop=["END"], + ) diff --git a/api/tests/unit_tests/core/app/workflow/layers/test_persistence.py b/api/tests/unit_tests/core/app/workflow/layers/test_persistence.py index 
0f8a846d11..dc2d82ccd6 100644 --- a/api/tests/unit_tests/core/app/workflow/layers/test_persistence.py +++ b/api/tests/unit_tests/core/app/workflow/layers/test_persistence.py @@ -2,14 +2,14 @@ from datetime import UTC, datetime from unittest.mock import Mock import pytest +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus, WorkflowType +from graphon.node_events import NodeRunResult from core.app.workflow.layers.persistence import ( PersistenceWorkflowInfo, WorkflowPersistenceLayer, _NodeRuntimeSnapshot, ) -from dify_graph.enums import WorkflowNodeExecutionStatus, WorkflowType -from dify_graph.node_events import NodeRunResult def _build_layer() -> WorkflowPersistenceLayer: @@ -58,3 +58,42 @@ def test_update_node_execution_prefers_event_finished_at(monkeypatch: pytest.Mon assert node_execution.finished_at == event_finished_at assert node_execution.elapsed_time == 2.0 + + +def test_update_node_execution_projects_start_outputs() -> None: + layer = _build_layer() + node_execution = Mock() + node_execution.id = "node-exec-2" + node_execution.node_type = BuiltinNodeTypes.START + node_execution.created_at = datetime(2024, 1, 1, 0, 0, 0, tzinfo=UTC).replace(tzinfo=None) + node_execution.update_from_mapping = Mock() + + layer._node_snapshots[node_execution.id] = _NodeRuntimeSnapshot( + node_id="start", + title="Start", + predecessor_node_id=None, + iteration_id=None, + loop_id=None, + created_at=node_execution.created_at, + ) + + layer._update_node_execution( + node_execution, + NodeRunResult( + status=WorkflowNodeExecutionStatus.SUCCEEDED, + inputs={"question": "hello"}, + outputs={ + "question": "hello", + "sys.query": "hello", + "env.API_KEY": "secret", + }, + ), + WorkflowNodeExecutionStatus.SUCCEEDED, + ) + + node_execution.update_from_mapping.assert_called_once_with( + inputs={"question": "hello"}, + process_data={}, + outputs={"question": "hello"}, + metadata={}, + ) diff --git a/api/tests/unit_tests/core/app/workflow/test_file_runtime.py 
b/api/tests/unit_tests/core/app/workflow/test_file_runtime.py index fb76f22a2a..7be9d6ac1e 100644 --- a/api/tests/unit_tests/core/app/workflow/test_file_runtime.py +++ b/api/tests/unit_tests/core/app/workflow/test_file_runtime.py @@ -1,43 +1,370 @@ -from unittest.mock import patch +from __future__ import annotations +import base64 +import hashlib +import hmac +from types import SimpleNamespace +from unittest.mock import MagicMock, patch +from urllib.parse import parse_qs, urlparse + +import pytest +from graphon.file import File, FileTransferMethod, FileType + +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from core.app.file_access import DatabaseFileAccessController, FileAccessScope +from core.app.workflow import file_runtime from core.app.workflow.file_runtime import DifyWorkflowFileRuntime, bind_dify_workflow_file_runtime +from core.workflow.file_reference import build_file_reference +from models import ToolFile, UploadFile -class TestDifyWorkflowFileRuntime: - def test_runtime_properties_and_helpers(self, monkeypatch): - monkeypatch.setattr("core.app.workflow.file_runtime.dify_config.FILES_URL", "http://files") - monkeypatch.setattr("core.app.workflow.file_runtime.dify_config.INTERNAL_FILES_URL", "http://internal") - monkeypatch.setattr("core.app.workflow.file_runtime.dify_config.SECRET_KEY", "secret") - monkeypatch.setattr("core.app.workflow.file_runtime.dify_config.FILES_ACCESS_TIMEOUT", 123) - monkeypatch.setattr("core.app.workflow.file_runtime.dify_config.MULTIMODAL_SEND_FORMAT", "url") +def _build_file( + *, + transfer_method: FileTransferMethod, + reference: str | None = None, + remote_url: str | None = None, + extension: str | None = None, +) -> File: + return File( + id="file-id", + type=FileType.IMAGE, + transfer_method=transfer_method, + reference=reference, + remote_url=remote_url, + filename="diagram.png", + extension=extension, + mime_type="image/png", + size=128, + ) - runtime = DifyWorkflowFileRuntime() - assert 
runtime.files_url == "http://files" - assert runtime.internal_files_url == "http://internal" - assert runtime.secret_key == "secret" - assert runtime.files_access_timeout == 123 - assert runtime.multimodal_send_format == "url" +def _build_runtime() -> DifyWorkflowFileRuntime: + return DifyWorkflowFileRuntime(file_access_controller=DatabaseFileAccessController()) - with patch("core.app.workflow.file_runtime.ssrf_proxy.get") as mock_get: - mock_get.return_value = "response" - assert runtime.http_get("http://example", follow_redirects=False) == "response" - mock_get.assert_called_once_with("http://example", follow_redirects=False) - with patch("core.app.workflow.file_runtime.storage.load") as mock_load: - mock_load.return_value = b"data" - assert runtime.storage_load("path", stream=True) == b"data" - mock_load.assert_called_once_with("path", stream=True) +def test_resolve_file_url_returns_remote_url() -> None: + runtime = _build_runtime() + file = _build_file( + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url="https://example.com/diagram.png", + ) - with patch("core.app.workflow.file_runtime.sign_tool_file") as mock_sign: - mock_sign.return_value = "signed" - assert runtime.sign_tool_file(tool_file_id="id", extension=".txt", for_external=False) == "signed" - mock_sign.assert_called_once_with(tool_file_id="id", extension=".txt", for_external=False) + assert runtime.resolve_file_url(file=file) == "https://example.com/diagram.png" - def test_bind_runtime_registers_instance(self): - with patch("core.app.workflow.file_runtime.set_workflow_file_runtime") as mock_set: - bind_dify_workflow_file_runtime() - mock_set.assert_called_once() - runtime = mock_set.call_args[0][0] - assert isinstance(runtime, DifyWorkflowFileRuntime) +def test_resolve_file_url_requires_file_reference() -> None: + runtime = _build_runtime() + file = SimpleNamespace(transfer_method=FileTransferMethod.LOCAL_FILE, reference=None) + + with pytest.raises(ValueError, match="Missing file 
reference"): + runtime.resolve_file_url(file=file) + + +def test_resolve_file_url_requires_extension_for_tool_files() -> None: + runtime = _build_runtime() + file = _build_file( + transfer_method=FileTransferMethod.TOOL_FILE, + reference=build_file_reference(record_id="tool-file-id"), + extension=None, + ) + + with pytest.raises(ValueError, match="Missing file extension"): + runtime.resolve_file_url(file=file) + + +def test_resolve_file_url_uses_tool_signatures_for_tool_and_datasource_files( + monkeypatch: pytest.MonkeyPatch, +) -> None: + sign_tool_file = MagicMock(return_value="https://signed.example.com/file") + monkeypatch.setattr(file_runtime, "sign_tool_file", sign_tool_file) + runtime = _build_runtime() + + tool_file = _build_file( + transfer_method=FileTransferMethod.TOOL_FILE, + reference=build_file_reference(record_id="tool-file-id"), + extension=".png", + ) + datasource_file = _build_file( + transfer_method=FileTransferMethod.DATASOURCE_FILE, + reference=build_file_reference(record_id="datasource-file-id"), + extension=".png", + ) + + assert runtime.resolve_file_url(file=tool_file) == "https://signed.example.com/file" + assert runtime.resolve_file_url(file=datasource_file) == "https://signed.example.com/file" + assert sign_tool_file.call_count == 2 + + +def test_resolve_upload_file_url_signs_internal_urls_and_supports_attachments( + monkeypatch: pytest.MonkeyPatch, +) -> None: + monkeypatch.setattr("core.app.workflow.file_runtime.time.time", lambda: 1700000000) + monkeypatch.setattr("core.app.workflow.file_runtime.os.urandom", lambda _: b"\x01" * 16) + monkeypatch.setattr("core.app.workflow.file_runtime.dify_config.SECRET_KEY", "unit-secret") + monkeypatch.setattr("core.app.workflow.file_runtime.dify_config.FILES_URL", "https://files.example.com") + monkeypatch.setattr( + "core.app.workflow.file_runtime.dify_config.INTERNAL_FILES_URL", + "https://internal.example.com", + ) + + runtime = _build_runtime() + url = runtime.resolve_upload_file_url( + 
upload_file_id="upload-file-id", + as_attachment=True, + for_external=False, + ) + parsed = urlparse(url) + query = parse_qs(parsed.query) + + assert parsed.netloc == "internal.example.com" + assert parsed.path == "/files/upload-file-id/file-preview" + assert query["as_attachment"] == ["true"] + assert query["timestamp"] == ["1700000000"] + + +def test_verify_preview_signature_validates_signature_and_expiration(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr("core.app.workflow.file_runtime.time.time", lambda: 1700000000) + monkeypatch.setattr("core.app.workflow.file_runtime.dify_config.SECRET_KEY", "unit-secret") + monkeypatch.setattr("core.app.workflow.file_runtime.dify_config.FILES_ACCESS_TIMEOUT", 60) + runtime = _build_runtime() + payload = "file-preview|upload-file-id|1700000000|nonce" + sign = base64.urlsafe_b64encode(hmac.new(b"unit-secret", payload.encode(), hashlib.sha256).digest()).decode() + + assert ( + runtime.verify_preview_signature( + preview_kind="file", + file_id="upload-file-id", + timestamp="1700000000", + nonce="nonce", + sign=sign, + ) + is True + ) + assert ( + runtime.verify_preview_signature( + preview_kind="file", + file_id="upload-file-id", + timestamp="1700000000", + nonce="nonce", + sign="bad-signature", + ) + is False + ) + + monkeypatch.setattr("core.app.workflow.file_runtime.time.time", lambda: 1700000100) + assert ( + runtime.verify_preview_signature( + preview_kind="file", + file_id="upload-file-id", + timestamp="1700000000", + nonce="nonce", + sign=sign, + ) + is False + ) + + +def test_load_file_bytes_returns_bytes_and_rejects_non_bytes(monkeypatch: pytest.MonkeyPatch) -> None: + runtime = _build_runtime() + file = _build_file( + transfer_method=FileTransferMethod.LOCAL_FILE, + reference=build_file_reference(record_id="upload-file-id"), + ) + session = MagicMock() + session.get.return_value = SimpleNamespace(key="canonical-storage-key") + + class _SessionContext: + def __enter__(self): + return session + + def 
__exit__(self, exc_type, exc, tb): + return False + + monkeypatch.setattr(file_runtime.session_factory, "create_session", lambda: _SessionContext()) + monkeypatch.setattr(file_runtime.storage, "load", lambda *args, **kwargs: b"image-bytes") + + assert runtime.load_file_bytes(file=file) == b"image-bytes" + session.get.assert_called_with(UploadFile, "upload-file-id") + + monkeypatch.setattr(file_runtime.storage, "load", lambda *args, **kwargs: "not-bytes") + with pytest.raises(ValueError, match="is not a bytes object"): + runtime.load_file_bytes(file=file) + + +def test_resolve_storage_key_ignores_encoded_reference_when_unscoped(monkeypatch: pytest.MonkeyPatch) -> None: + runtime = _build_runtime() + file = _build_file( + transfer_method=FileTransferMethod.LOCAL_FILE, + reference=build_file_reference(record_id="upload-file-id", storage_key="tampered-storage-key"), + ) + session = MagicMock() + session.get.return_value = SimpleNamespace(key="canonical-storage-key") + + class _SessionContext: + def __enter__(self): + return session + + def __exit__(self, exc_type, exc, tb): + return False + + monkeypatch.setattr(file_runtime.session_factory, "create_session", lambda: _SessionContext()) + + assert runtime._resolve_storage_key(file=file) == "canonical-storage-key" + session.get.assert_called_once_with(UploadFile, "upload-file-id") + + +def test_resolve_storage_key_uses_canonical_record_when_scope_is_bound(monkeypatch: pytest.MonkeyPatch) -> None: + controller = MagicMock() + controller.current_scope.return_value = FileAccessScope( + tenant_id="tenant-id", + user_id="end-user-id", + user_from=UserFrom.END_USER, + invoke_from=InvokeFrom.WEB_APP, + ) + controller.get_upload_file.return_value = SimpleNamespace(key="canonical-storage-key") + runtime = DifyWorkflowFileRuntime(file_access_controller=controller) + file = _build_file( + transfer_method=FileTransferMethod.LOCAL_FILE, + reference=build_file_reference(record_id="upload-file-id", storage_key="tampered-storage-key"), 
+ ) + session = MagicMock() + + class _SessionContext: + def __enter__(self): + return session + + def __exit__(self, exc_type, exc, tb): + return False + + monkeypatch.setattr(file_runtime.session_factory, "create_session", lambda: _SessionContext()) + + assert runtime._resolve_storage_key(file=file) == "canonical-storage-key" + controller.get_upload_file.assert_called_once_with(session=session, file_id="upload-file-id") + + +def test_resolve_upload_file_url_rejects_unauthorized_scoped_access(monkeypatch: pytest.MonkeyPatch) -> None: + controller = MagicMock() + controller.current_scope.return_value = FileAccessScope( + tenant_id="tenant-id", + user_id="end-user-id", + user_from=UserFrom.END_USER, + invoke_from=InvokeFrom.WEB_APP, + ) + controller.get_upload_file.return_value = None + runtime = DifyWorkflowFileRuntime(file_access_controller=controller) + session = MagicMock() + + class _SessionContext: + def __enter__(self): + return session + + def __exit__(self, exc_type, exc, tb): + return False + + monkeypatch.setattr(file_runtime.session_factory, "create_session", lambda: _SessionContext()) + + with pytest.raises(ValueError, match="Upload file upload-file-id not found"): + runtime.resolve_upload_file_url(upload_file_id="upload-file-id") + + +@pytest.mark.parametrize( + ("transfer_method", "record_id", "expected_storage_key"), + [ + (FileTransferMethod.LOCAL_FILE, "upload-file-id", "upload-storage-key"), + (FileTransferMethod.DATASOURCE_FILE, "upload-file-id", "upload-storage-key"), + (FileTransferMethod.TOOL_FILE, "tool-file-id", "tool-storage-key"), + ], +) +def test_resolve_storage_key_loads_database_records( + monkeypatch: pytest.MonkeyPatch, + transfer_method: FileTransferMethod, + record_id: str, + expected_storage_key: str, +) -> None: + runtime = _build_runtime() + file = _build_file( + transfer_method=transfer_method, + reference=build_file_reference(record_id=record_id), + extension=".png", + ) + session = MagicMock() + + def get(model_class, value): 
+ if transfer_method in {FileTransferMethod.LOCAL_FILE, FileTransferMethod.DATASOURCE_FILE}: + assert model_class is UploadFile + return SimpleNamespace(key="upload-storage-key") + assert model_class is ToolFile + return SimpleNamespace(file_key="tool-storage-key") + + session.get.side_effect = get + + class _SessionContext: + def __enter__(self): + return session + + def __exit__(self, exc_type, exc, tb): + return False + + monkeypatch.setattr(file_runtime.session_factory, "create_session", lambda: _SessionContext()) + + assert runtime._resolve_storage_key(file=file) == expected_storage_key + + +@pytest.mark.parametrize( + ("transfer_method", "expected_message"), + [ + (FileTransferMethod.LOCAL_FILE, "Upload file upload-file-id not found"), + (FileTransferMethod.TOOL_FILE, "Tool file tool-file-id not found"), + ], +) +def test_resolve_storage_key_raises_when_records_are_missing( + monkeypatch: pytest.MonkeyPatch, + transfer_method: FileTransferMethod, + expected_message: str, +) -> None: + runtime = _build_runtime() + record_id = "upload-file-id" if transfer_method == FileTransferMethod.LOCAL_FILE else "tool-file-id" + file = _build_file( + transfer_method=transfer_method, + reference=build_file_reference(record_id=record_id), + extension=".png", + ) + session = MagicMock() + session.get.return_value = None + + class _SessionContext: + def __enter__(self): + return session + + def __exit__(self, exc_type, exc, tb): + return False + + monkeypatch.setattr(file_runtime.session_factory, "create_session", lambda: _SessionContext()) + + with pytest.raises(ValueError, match=expected_message): + runtime._resolve_storage_key(file=file) + + +def test_runtime_helper_wrappers_delegate_to_config_and_io(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr("core.app.workflow.file_runtime.dify_config.MULTIMODAL_SEND_FORMAT", "url") + runtime = _build_runtime() + + assert runtime.multimodal_send_format == "url" + + with patch.object(file_runtime.ssrf_proxy, "get", 
return_value="response") as mock_get: + assert runtime.http_get("http://example", follow_redirects=False) == "response" + mock_get.assert_called_once_with("http://example", follow_redirects=False) + + with patch.object(file_runtime.storage, "load", return_value=b"data") as mock_load: + assert runtime.storage_load("path", stream=True) == b"data" + mock_load.assert_called_once_with("path", stream=True) + + +def test_bind_dify_workflow_file_runtime_registers_runtime(monkeypatch: pytest.MonkeyPatch) -> None: + set_runtime = MagicMock() + monkeypatch.setattr(file_runtime, "set_workflow_file_runtime", set_runtime) + + bind_dify_workflow_file_runtime() + + set_runtime.assert_called_once() + assert isinstance(set_runtime.call_args.args[0], DifyWorkflowFileRuntime) diff --git a/api/tests/unit_tests/core/app/workflow/test_node_factory.py b/api/tests/unit_tests/core/app/workflow/test_node_factory.py index 9e742507c6..8497261d45 100644 --- a/api/tests/unit_tests/core/app/workflow/test_node_factory.py +++ b/api/tests/unit_tests/core/app/workflow/test_node_factory.py @@ -1,10 +1,10 @@ from types import SimpleNamespace import pytest +from graphon.enums import BuiltinNodeTypes from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context from core.workflow.node_factory import DifyNodeFactory -from dify_graph.enums import BuiltinNodeTypes class DummyNode: @@ -131,7 +131,7 @@ class TestDifyNodeFactory: node = factory.create_node({"id": "node-1", "data": {"type": BuiltinNodeTypes.TEMPLATE_TRANSFORM}}) assert isinstance(node, DummyTemplateTransformNode) - assert "template_renderer" in node.kwargs + assert "jinja2_template_renderer" in node.kwargs def test_create_node_http_request_branch(self, monkeypatch): factory = self._factory(monkeypatch) diff --git a/api/tests/unit_tests/core/app/workflow/test_observability_layer_extra.py b/api/tests/unit_tests/core/app/workflow/test_observability_layer_extra.py index 0565f4cfe9..a47d3db6f5 100644 --- 
a/api/tests/unit_tests/core/app/workflow/test_observability_layer_extra.py +++ b/api/tests/unit_tests/core/app/workflow/test_observability_layer_extra.py @@ -2,8 +2,9 @@ from __future__ import annotations from types import SimpleNamespace +from graphon.enums import BuiltinNodeTypes + from core.app.workflow.layers.observability import ObservabilityLayer -from dify_graph.enums import BuiltinNodeTypes class TestObservabilityLayerExtras: diff --git a/api/tests/unit_tests/core/app/workflow/test_persistence_layer.py b/api/tests/unit_tests/core/app/workflow/test_persistence_layer.py index 45f6a0c7a1..d8a68f6d00 100644 --- a/api/tests/unit_tests/core/app/workflow/test_persistence_layer.py +++ b/api/tests/unit_tests/core/app/workflow/test_persistence_layer.py @@ -4,27 +4,21 @@ from datetime import UTC, datetime from types import SimpleNamespace import pytest - -from core.app.entities.app_invoke_entities import WorkflowAppGenerateEntity -from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer -from dify_graph.entities.pause_reason import SchedulingPause -from dify_graph.entities.workflow_node_execution import WorkflowNodeExecution -from dify_graph.enums import ( +from graphon.entities import WorkflowNodeExecution +from graphon.entities.pause_reason import SchedulingPause +from graphon.enums import ( BuiltinNodeTypes, - SystemVariableKey, WorkflowExecutionStatus, WorkflowNodeExecutionStatus, WorkflowType, ) -from dify_graph.graph_events.graph import ( +from graphon.graph_events import ( GraphRunAbortedEvent, GraphRunFailedEvent, GraphRunPartialSucceededEvent, GraphRunPausedEvent, GraphRunStartedEvent, GraphRunSucceededEvent, -) -from dify_graph.graph_events.node import ( NodeRunExceptionEvent, NodeRunFailedEvent, NodeRunPauseRequestedEvent, @@ -32,9 +26,12 @@ from dify_graph.graph_events.node import ( NodeRunStartedEvent, NodeRunSucceededEvent, ) -from dify_graph.node_events import NodeRunResult -from dify_graph.runtime import 
GraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper, VariablePool -from dify_graph.system_variable import SystemVariable +from graphon.node_events import NodeRunResult +from graphon.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper, VariablePool + +from core.app.entities.app_invoke_entities import WorkflowAppGenerateEntity +from core.app.workflow.layers.persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer +from core.workflow.system_variables import SystemVariableKey, build_system_variables class _RepoRecorder: @@ -54,13 +51,16 @@ def _naive_utc_now() -> datetime: def _make_layer( - system_variable: SystemVariable | None = None, + system_variables: list | None = None, *, extras: dict | None = None, trace_manager: object | None = None, ): - system_variable = system_variable or SystemVariable(workflow_execution_id="run-id", conversation_id="conv-id") - runtime_state = GraphRuntimeState(variable_pool=VariablePool(system_variables=system_variable), start_at=0.0) + system_variables = system_variables or build_system_variables( + workflow_execution_id="run-id", + conversation_id="conv-id", + ) + runtime_state = GraphRuntimeState(variable_pool=VariablePool(system_variables=system_variables), start_at=0.0) read_only_state = ReadOnlyGraphRuntimeStateWrapper(runtime_state) application_generate_entity = WorkflowAppGenerateEntity.model_construct( @@ -115,8 +115,7 @@ class TestWorkflowPersistenceLayer: assert layer._node_sequence == 0 def test_get_execution_id_requires_system_variable(self): - system_variable = SystemVariable(workflow_execution_id=None) - layer, _, _, _ = _make_layer(system_variable) + layer, _, _, _ = _make_layer(build_system_variables()) with pytest.raises(ValueError, match="workflow_execution_id must be provided"): layer._get_execution_id() diff --git a/api/tests/unit_tests/core/base/test_app_generator_tts_publisher.py b/api/tests/unit_tests/core/base/test_app_generator_tts_publisher.py index 3759b6aa37..5ff9774b52 100644 --- 
a/api/tests/unit_tests/core/base/test_app_generator_tts_publisher.py +++ b/api/tests/unit_tests/core/base/test_app_generator_tts_publisher.py @@ -28,10 +28,7 @@ def mock_model_instance(mocker): def mock_model_manager(mocker, mock_model_instance): manager = mocker.MagicMock() manager.get_default_model_instance.return_value = mock_model_instance - mocker.patch( - "core.base.tts.app_generator_tts_publisher.ModelManager", - return_value=manager, - ) + mocker.patch("core.base.tts.app_generator_tts_publisher.ModelManager.for_tenant", return_value=manager) return manager @@ -64,16 +61,14 @@ class TestInvoiceTTS: [None, "", " "], ) def test_invoice_tts_empty_or_none_returns_none(self, text, mock_model_instance): - result = _invoice_tts(text, mock_model_instance, "tenant", "voice1") + result = _invoice_tts(text, mock_model_instance, "voice1") assert result is None mock_model_instance.invoke_tts.assert_not_called() def test_invoice_tts_valid_text(self, mock_model_instance): - result = _invoice_tts(" hello ", mock_model_instance, "tenant", "voice1") + result = _invoice_tts(" hello ", mock_model_instance, "voice1") mock_model_instance.invoke_tts.assert_called_once_with( content_text="hello", - user="responding_tts", - tenant_id="tenant", voice="voice1", ) assert result == [b"audio1", b"audio2"] @@ -306,14 +301,15 @@ class TestAppGeneratorTTSPublisher: publisher = AppGeneratorTTSPublisher("tenant", "voice1") publisher.executor = MagicMock() - from core.app.entities.queue_entities import QueueAgentMessageEvent - from dify_graph.model_runtime.entities.llm_entities import LLMResultChunk, LLMResultChunkDelta - from dify_graph.model_runtime.entities.message_entities import ( + from graphon.model_runtime.entities.llm_entities import LLMResultChunk, LLMResultChunkDelta + from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, ImagePromptMessageContent, TextPromptMessageContent, ) + from core.app.entities.queue_entities import QueueAgentMessageEvent + 
chunk = LLMResultChunk( model="model", delta=LLMResultChunkDelta( @@ -341,9 +337,10 @@ class TestAppGeneratorTTSPublisher: publisher = AppGeneratorTTSPublisher("tenant", "voice1") publisher.executor = MagicMock() + from graphon.model_runtime.entities.llm_entities import LLMResultChunk, LLMResultChunkDelta + from graphon.model_runtime.entities.message_entities import AssistantPromptMessage + from core.app.entities.queue_entities import QueueAgentMessageEvent - from dify_graph.model_runtime.entities.llm_entities import LLMResultChunk, LLMResultChunkDelta - from dify_graph.model_runtime.entities.message_entities import AssistantPromptMessage chunk = LLMResultChunk( model="model", diff --git a/api/tests/unit_tests/core/callback_handler/test_index_tool_callback_handler.py b/api/tests/unit_tests/core/callback_handler/test_index_tool_callback_handler.py index b37c4c57a1..8e5670e9be 100644 --- a/api/tests/unit_tests/core/callback_handler/test_index_tool_callback_handler.py +++ b/api/tests/unit_tests/core/callback_handler/test_index_tool_callback_handler.py @@ -114,13 +114,9 @@ class TestOnToolEnd: document = mocker.Mock() document.metadata = {"document_id": "doc-1", "doc_id": "node-1"} - mock_query = mocker.Mock() - mock_db.session.query.return_value = mock_query - mock_query.where.return_value = mock_query - handler.on_tool_end([document]) - mock_query.update.assert_called_once() + mock_db.session.execute.assert_called_once() mock_db.session.commit.assert_called_once() def test_on_tool_end_non_parent_child_index(self, handler, mocker): @@ -138,13 +134,9 @@ class TestOnToolEnd: "dataset_id": "dataset-1", } - mock_query = mocker.Mock() - mock_db.session.query.return_value = mock_query - mock_query.where.return_value = mock_query - handler.on_tool_end([document]) - mock_query.update.assert_called_once() + mock_db.session.execute.assert_called_once() mock_db.session.commit.assert_called_once() def test_on_tool_end_empty_documents(self, handler): diff --git 
a/api/tests/unit_tests/core/datasource/test_datasource_manager.py b/api/tests/unit_tests/core/datasource/test_datasource_manager.py index d5eeae912c..b0c72ee42f 100644 --- a/api/tests/unit_tests/core/datasource/test_datasource_manager.py +++ b/api/tests/unit_tests/core/datasource/test_datasource_manager.py @@ -2,15 +2,15 @@ import types from collections.abc import Generator import pytest +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.file import File, FileTransferMethod, FileType +from graphon.node_events import StreamChunkEvent, StreamCompletedEvent from contexts.wrapper import RecyclableContextVar from core.datasource.datasource_manager import DatasourceManager from core.datasource.entities.datasource_entities import DatasourceMessage, DatasourceProviderType from core.datasource.errors import DatasourceProviderNotFoundError -from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from dify_graph.file import File -from dify_graph.file.enums import FileTransferMethod, FileType -from dify_graph.node_events import StreamChunkEvent, StreamCompletedEvent +from core.workflow.file_reference import parse_file_reference def _gen_messages_text_only(text: str) -> Generator[DatasourceMessage, None, None]: @@ -428,11 +428,8 @@ def test_stream_node_events_builds_file_and_variables_from_messages(mocker): return fake_tool_file mocker.patch("core.datasource.datasource_manager.session_factory.create_session", return_value=_Session()) - mocker.patch( - "core.datasource.datasource_manager.file_factory.get_file_type_by_mime_type", return_value=FileType.IMAGE - ) + mocker.patch("core.datasource.datasource_manager.get_file_type_by_mime_type", return_value=FileType.IMAGE) built = File( - tenant_id="t1", type=FileType.IMAGE, transfer_method=FileTransferMethod.TOOL_FILE, related_id="tool_file_1", @@ -533,7 +530,6 @@ def test_stream_node_events_online_drive_sets_variable_pool_file_and_outputs(moc mocker.patch.object(DatasourceManager, 
"stream_online_results", return_value=_gen_messages_text_only("ignored")) file_in = File( - tenant_id="t1", type=FileType.DOCUMENT, transfer_method=FileTransferMethod.TOOL_FILE, related_id="tf", @@ -664,6 +660,8 @@ def test_get_upload_file_by_id_builds_file(mocker): f = DatasourceManager.get_upload_file_by_id(file_id="fid", tenant_id="t1") assert f.related_id == "fid" assert f.extension == ".txt" + assert parse_file_reference(f.reference).storage_key is None + assert f.storage_key == "k" def test_get_upload_file_by_id_raises_when_missing(mocker): diff --git a/api/tests/unit_tests/core/datasource/utils/test_message_transformer.py b/api/tests/unit_tests/core/datasource/utils/test_message_transformer.py index 43f582feb7..fbaf6d497d 100644 --- a/api/tests/unit_tests/core/datasource/utils/test_message_transformer.py +++ b/api/tests/unit_tests/core/datasource/utils/test_message_transformer.py @@ -1,11 +1,10 @@ from unittest.mock import MagicMock, patch import pytest +from graphon.file import File, FileTransferMethod, FileType from core.datasource.entities.datasource_entities import DatasourceMessage from core.datasource.utils.message_transformer import DatasourceFileMessageTransformer -from dify_graph.file import File -from dify_graph.file.enums import FileTransferMethod, FileType from models.tools import ToolFile diff --git a/api/tests/unit_tests/core/entities/test_entities_execution_extra_content.py b/api/tests/unit_tests/core/entities/test_entities_execution_extra_content.py index 2e4f6d34fb..ff9fd0d8f3 100644 --- a/api/tests/unit_tests/core/entities/test_entities_execution_extra_content.py +++ b/api/tests/unit_tests/core/entities/test_entities_execution_extra_content.py @@ -1,11 +1,12 @@ +from graphon.nodes.human_input.entities import FormInput, UserAction +from graphon.nodes.human_input.enums import FormInputType + from core.entities.execution_extra_content import ( ExecutionExtraContentDomainModel, HumanInputContent, HumanInputFormDefinition, 
HumanInputFormSubmissionData, ) -from dify_graph.nodes.human_input.entities import FormInput, UserAction -from dify_graph.nodes.human_input.enums import FormInputType from models.execution_extra_content import ExecutionContentType diff --git a/api/tests/unit_tests/core/entities/test_entities_model_entities.py b/api/tests/unit_tests/core/entities/test_entities_model_entities.py index 7a3d5e84ed..2acd278a31 100644 --- a/api/tests/unit_tests/core/entities/test_entities_model_entities.py +++ b/api/tests/unit_tests/core/entities/test_entities_model_entities.py @@ -8,6 +8,9 @@ drive provider mapping behavior. """ import pytest +from graphon.model_runtime.entities.common_entities import I18nObject +from graphon.model_runtime.entities.model_entities import FetchFrom, ModelType +from graphon.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity from core.entities.model_entities import ( DefaultModelEntity, @@ -16,9 +19,6 @@ from core.entities.model_entities import ( ProviderModelWithStatusEntity, SimpleModelProviderEntity, ) -from dify_graph.model_runtime.entities.common_entities import I18nObject -from dify_graph.model_runtime.entities.model_entities import FetchFrom, ModelType -from dify_graph.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity def _build_model_with_status(status: ModelStatus) -> ProviderModelWithStatusEntity: diff --git a/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py b/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py index 95d58757f1..8cf0409c4c 100644 --- a/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py +++ b/api/tests/unit_tests/core/entities/test_entities_provider_configuration.py @@ -6,6 +6,17 @@ from typing import Any from unittest.mock import Mock, patch import pytest +from graphon.model_runtime.entities.common_entities import I18nObject +from graphon.model_runtime.entities.model_entities import AIModelEntity, 
FetchFrom, ModelType +from graphon.model_runtime.entities.provider_entities import ( + ConfigurateMethod, + CredentialFormSchema, + FieldModelSchema, + FormType, + ModelCredentialSchema, + ProviderCredentialSchema, + ProviderEntity, +) from constants import HIDDEN_VALUE from core.entities.model_entities import ModelStatus @@ -24,17 +35,6 @@ from core.entities.provider_entities import ( SystemConfiguration, SystemConfigurationStatus, ) -from dify_graph.model_runtime.entities.common_entities import I18nObject -from dify_graph.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType -from dify_graph.model_runtime.entities.provider_entities import ( - ConfigurateMethod, - CredentialFormSchema, - FieldModelSchema, - FormType, - ModelCredentialSchema, - ProviderCredentialSchema, - ProviderEntity, -) from models.enums import CredentialSourceType from models.provider import ProviderType from models.provider_ids import ModelProviderID @@ -350,7 +350,7 @@ def test_validate_provider_credentials_handles_hidden_secret_value() -> None: mock_factory = Mock() mock_factory.provider_credentials_validate.return_value = {"openai_api_key": "restored-key", "region": "us"} - with patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory): + with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory): with patch("core.entities.provider_configuration.encrypter.decrypt_token", return_value="restored-key"): with patch( "core.entities.provider_configuration.encrypter.encrypt_token", @@ -380,7 +380,9 @@ def test_validate_provider_credentials_opens_session_when_not_passed() -> None: with patch("core.entities.provider_configuration.db") as mock_db: mock_db.engine = Mock() mock_session_cls.return_value.__enter__.return_value = mock_session - with patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory): + with patch( + 
"core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory + ): validated = configuration.validate_provider_credentials(credentials={"region": "us"}) assert validated == {"region": "us"} @@ -434,12 +436,16 @@ def test_get_model_type_instance_and_schema_delegate_to_factory() -> None: mock_factory.get_model_type_instance.return_value = mock_model_type_instance mock_factory.get_model_schema.return_value = mock_schema - with patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory): + with patch( + "core.entities.provider_configuration.create_plugin_model_provider_factory", + return_value=mock_factory, + ) as mock_factory_builder: model_type_instance = configuration.get_model_type_instance(ModelType.LLM) model_schema = configuration.get_model_schema(ModelType.LLM, "gpt-4o", {"api_key": "x"}) assert model_type_instance is mock_model_type_instance assert model_schema is mock_schema + assert mock_factory_builder.call_count == 2 mock_factory.get_model_type_instance.assert_called_once_with(provider="openai", model_type=ModelType.LLM) mock_factory.get_model_schema.assert_called_once_with( provider="openai", @@ -449,6 +455,33 @@ def test_get_model_type_instance_and_schema_delegate_to_factory() -> None: ) +def test_get_model_type_instance_and_schema_reuse_bound_runtime_factory() -> None: + configuration = _build_provider_configuration() + bound_runtime = Mock() + configuration.bind_model_runtime(bound_runtime) + + mock_factory = Mock() + mock_model_type_instance = Mock() + mock_schema = _build_ai_model("gpt-4o") + mock_factory.get_model_type_instance.return_value = mock_model_type_instance + mock_factory.get_model_schema.return_value = mock_schema + + with ( + patch( + "core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory + ) as mock_factory_cls, + patch("core.entities.provider_configuration.create_plugin_model_provider_factory") as mock_factory_builder, + ): + 
model_type_instance = configuration.get_model_type_instance(ModelType.LLM) + model_schema = configuration.get_model_schema(ModelType.LLM, "gpt-4o", {"api_key": "x"}) + + assert model_type_instance is mock_model_type_instance + assert model_schema is mock_schema + assert mock_factory_cls.call_count == 2 + mock_factory_cls.assert_called_with(model_runtime=bound_runtime) + mock_factory_builder.assert_not_called() + + def test_get_provider_model_returns_none_when_model_not_found() -> None: configuration = _build_provider_configuration() fake_model = SimpleNamespace(model="other-model") @@ -475,7 +508,7 @@ def test_get_provider_models_system_deduplicates_sorts_and_filters_active() -> N mock_factory = Mock() mock_factory.get_provider_schema.return_value = provider_schema - with patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory): + with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory): all_models = configuration.get_provider_models(model_type=ModelType.LLM, only_active=False) active_models = configuration.get_provider_models(model_type=ModelType.LLM, only_active=True) @@ -689,7 +722,7 @@ def test_validate_provider_credentials_handles_invalid_original_json() -> None: mock_factory = Mock() mock_factory.provider_credentials_validate.return_value = {"openai_api_key": "new-key"} - with patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory): + with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory): with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-key"): validated = configuration.validate_provider_credentials( credentials={"openai_api_key": HIDDEN_VALUE}, @@ -1034,7 +1067,7 @@ def test_validate_custom_model_credentials_supports_hidden_reuse_and_sessionless mock_factory = Mock() mock_factory.model_credentials_validate.return_value = 
{"openai_api_key": "raw"} - with patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory): + with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory): with patch("core.entities.provider_configuration.encrypter.decrypt_token", return_value="raw"): with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"): validated = configuration.validate_custom_model_credentials( @@ -1050,7 +1083,9 @@ def test_validate_custom_model_credentials_supports_hidden_reuse_and_sessionless mock_factory = Mock() mock_factory.model_credentials_validate.return_value = {"region": "us"} with _patched_session(session): - with patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory): + with patch( + "core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory + ): validated = configuration.validate_custom_model_credentials( model_type=ModelType.LLM, model="gpt-4o", @@ -1540,7 +1575,7 @@ def test_validate_provider_credentials_uses_empty_original_when_record_missing() mock_factory = Mock() mock_factory.provider_credentials_validate.return_value = {"openai_api_key": "raw"} - with patch("core.entities.provider_configuration.ModelProviderFactory", return_value=mock_factory): + with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory): with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"): validated = configuration.validate_provider_credentials( credentials={"openai_api_key": HIDDEN_VALUE}, @@ -1662,7 +1697,7 @@ def test_validate_custom_model_credentials_handles_invalid_original_json() -> No mock_factory = Mock() mock_factory.model_credentials_validate.return_value = {"openai_api_key": "raw"} - with patch("core.entities.provider_configuration.ModelProviderFactory", 
return_value=mock_factory): + with patch("core.entities.provider_configuration.create_plugin_model_provider_factory", return_value=mock_factory): with patch("core.entities.provider_configuration.encrypter.encrypt_token", return_value="enc-new"): validated = configuration.validate_custom_model_credentials( model_type=ModelType.LLM, diff --git a/api/tests/unit_tests/core/entities/test_entities_provider_entities.py b/api/tests/unit_tests/core/entities/test_entities_provider_entities.py index c5bfd05a1e..8685d16283 100644 --- a/api/tests/unit_tests/core/entities/test_entities_provider_entities.py +++ b/api/tests/unit_tests/core/entities/test_entities_provider_entities.py @@ -1,4 +1,5 @@ import pytest +from graphon.model_runtime.entities.model_entities import ModelType from core.entities.parameter_entities import AppSelectorScope from core.entities.provider_entities import ( @@ -8,7 +9,6 @@ from core.entities.provider_entities import ( ProviderQuotaType, ) from core.tools.entities.common_entities import I18nObject -from dify_graph.model_runtime.entities.model_entities import ModelType def test_provider_quota_type_value_of_returns_enum_member() -> None: diff --git a/api/tests/unit_tests/core/file/test_models.py b/api/tests/unit_tests/core/file/test_models.py index deebf41320..bb6e40e224 100644 --- a/api/tests/unit_tests/core/file/test_models.py +++ b/api/tests/unit_tests/core/file/test_models.py @@ -1,4 +1,4 @@ -from dify_graph.file import File, FileTransferMethod, FileType +from graphon.file import File, FileTransferMethod, FileType def test_file(): @@ -15,18 +15,17 @@ def test_file(): storage_key="test-storage-key", url="https://example.com/image.png", ) - assert file.tenant_id == "test-tenant-id" assert file.type == FileType.IMAGE assert file.transfer_method == FileTransferMethod.TOOL_FILE assert file.related_id == "test-related-id" + assert file.storage_key == "test-storage-key" assert file.filename == "image.png" assert file.extension == ".png" assert file.mime_type 
== "image/png" assert file.size == 67 -def test_file_model_validate_with_legacy_fields(): - """Test `File` model can handle data containing compatibility fields.""" +def test_file_model_validate_accepts_legacy_tenant_id(): data = { "id": "test-file", "tenant_id": "test-tenant-id", @@ -45,10 +44,8 @@ def test_file_model_validate_with_legacy_fields(): "datasource_file_id": "datasource-file-789", } - # Should be able to create `File` object without raising an exception file = File.model_validate(data) - # The File object does not have tool_file_id, upload_file_id, or datasource_file_id as attributes. - # Instead, check it does not expose unrecognized legacy fields (should raise on getattr). - for legacy_field in ("tool_file_id", "upload_file_id", "datasource_file_id"): - assert not hasattr(file, legacy_field) + assert file.related_id == "test-related-id" + assert file.storage_key == "test-storage-key" + assert "tenant_id" not in file.model_dump() diff --git a/api/tests/unit_tests/core/helper/test_encrypter.py b/api/tests/unit_tests/core/helper/test_encrypter.py index 5890009742..f3ef7fccd0 100644 --- a/api/tests/unit_tests/core/helper/test_encrypter.py +++ b/api/tests/unit_tests/core/helper/test_encrypter.py @@ -38,13 +38,13 @@ class TestObfuscatedToken: class TestEncryptToken: - @patch("models.engine.db.session.query") + @patch("extensions.ext_database.db.session.get") @patch("libs.rsa.encrypt") def test_successful_encryption(self, mock_encrypt, mock_query): """Test successful token encryption""" mock_tenant = MagicMock() mock_tenant.encrypt_public_key = "mock_public_key" - mock_query.return_value.where.return_value.first.return_value = mock_tenant + mock_query.return_value = mock_tenant mock_encrypt.return_value = b"encrypted_data" result = encrypt_token("tenant-123", "test_token") @@ -52,10 +52,10 @@ class TestEncryptToken: assert result == base64.b64encode(b"encrypted_data").decode() mock_encrypt.assert_called_with("test_token", "mock_public_key") - 
@patch("models.engine.db.session.query") + @patch("extensions.ext_database.db.session.get") def test_tenant_not_found(self, mock_query): """Test error when tenant doesn't exist""" - mock_query.return_value.where.return_value.first.return_value = None + mock_query.return_value = None with pytest.raises(ValueError) as exc_info: encrypt_token("invalid-tenant", "test_token") @@ -119,7 +119,7 @@ class TestGetDecryptDecoding: class TestEncryptDecryptIntegration: - @patch("models.engine.db.session.query") + @patch("extensions.ext_database.db.session.get") @patch("libs.rsa.encrypt") @patch("libs.rsa.decrypt") def test_should_encrypt_and_decrypt_consistently(self, mock_decrypt, mock_encrypt, mock_query): @@ -127,7 +127,7 @@ class TestEncryptDecryptIntegration: # Setup mock tenant mock_tenant = MagicMock() mock_tenant.encrypt_public_key = "mock_public_key" - mock_query.return_value.where.return_value.first.return_value = mock_tenant + mock_query.return_value = mock_tenant # Setup mock encryption/decryption original_token = "test_token_123" @@ -146,14 +146,14 @@ class TestEncryptDecryptIntegration: class TestSecurity: """Critical security tests for encryption system""" - @patch("models.engine.db.session.query") + @patch("extensions.ext_database.db.session.get") @patch("libs.rsa.encrypt") def test_cross_tenant_isolation(self, mock_encrypt, mock_query): """Ensure tokens encrypted for one tenant cannot be used by another""" # Setup mock tenant mock_tenant = MagicMock() mock_tenant.encrypt_public_key = "tenant1_public_key" - mock_query.return_value.where.return_value.first.return_value = mock_tenant + mock_query.return_value = mock_tenant mock_encrypt.return_value = b"encrypted_for_tenant1" # Encrypt token for tenant1 @@ -181,12 +181,12 @@ class TestSecurity: with pytest.raises(Exception, match="Decryption error"): decrypt_token("tenant-123", tampered) - @patch("models.engine.db.session.query") + @patch("extensions.ext_database.db.session.get") @patch("libs.rsa.encrypt") def 
test_encryption_randomness(self, mock_encrypt, mock_query): """Ensure same plaintext produces different ciphertext""" mock_tenant = MagicMock(encrypt_public_key="key") - mock_query.return_value.where.return_value.first.return_value = mock_tenant + mock_query.return_value = mock_tenant # Different outputs for same input mock_encrypt.side_effect = [b"enc1", b"enc2", b"enc3"] @@ -205,13 +205,13 @@ class TestEdgeCases: # Test empty string (which is a valid str type) assert obfuscated_token("") == "" - @patch("models.engine.db.session.query") + @patch("extensions.ext_database.db.session.get") @patch("libs.rsa.encrypt") def test_should_handle_empty_token_encryption(self, mock_encrypt, mock_query): """Test encryption of empty token""" mock_tenant = MagicMock() mock_tenant.encrypt_public_key = "mock_public_key" - mock_query.return_value.where.return_value.first.return_value = mock_tenant + mock_query.return_value = mock_tenant mock_encrypt.return_value = b"encrypted_empty" result = encrypt_token("tenant-123", "") @@ -219,13 +219,13 @@ class TestEdgeCases: assert result == base64.b64encode(b"encrypted_empty").decode() mock_encrypt.assert_called_with("", "mock_public_key") - @patch("models.engine.db.session.query") + @patch("extensions.ext_database.db.session.get") @patch("libs.rsa.encrypt") def test_should_handle_special_characters_in_token(self, mock_encrypt, mock_query): """Test tokens containing special/unicode characters""" mock_tenant = MagicMock() mock_tenant.encrypt_public_key = "mock_public_key" - mock_query.return_value.where.return_value.first.return_value = mock_tenant + mock_query.return_value = mock_tenant mock_encrypt.return_value = b"encrypted_special" # Test various special characters @@ -242,13 +242,13 @@ class TestEdgeCases: assert result == base64.b64encode(b"encrypted_special").decode() mock_encrypt.assert_called_with(token, "mock_public_key") - @patch("models.engine.db.session.query") + @patch("extensions.ext_database.db.session.get") 
@patch("libs.rsa.encrypt") def test_should_handle_rsa_size_limits(self, mock_encrypt, mock_query): """Test behavior when token exceeds RSA encryption limits""" mock_tenant = MagicMock() mock_tenant.encrypt_public_key = "mock_public_key" - mock_query.return_value.where.return_value.first.return_value = mock_tenant + mock_query.return_value = mock_tenant # RSA 2048-bit can only encrypt ~245 bytes # The actual limit depends on padding scheme diff --git a/api/tests/unit_tests/core/llm_generator/output_parser/test_structured_output.py b/api/tests/unit_tests/core/llm_generator/output_parser/test_structured_output.py index 46c9dc6f9c..b45f6fd9a7 100644 --- a/api/tests/unit_tests/core/llm_generator/output_parser/test_structured_output.py +++ b/api/tests/unit_tests/core/llm_generator/output_parser/test_structured_output.py @@ -2,6 +2,20 @@ import json from unittest.mock import MagicMock, patch import pytest +from graphon.model_runtime.entities.llm_entities import ( + LLMResult, + LLMResultChunk, + LLMResultChunkDelta, + LLMResultWithStructuredOutput, + LLMUsage, +) +from graphon.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + SystemPromptMessage, + TextPromptMessageContent, + UserPromptMessage, +) +from graphon.model_runtime.entities.model_entities import AIModelEntity, ParameterRule, ParameterType from core.llm_generator.output_parser.errors import OutputParserError from core.llm_generator.output_parser.structured_output import ( @@ -16,20 +30,6 @@ from core.llm_generator.output_parser.structured_output import ( remove_additional_properties, ) from core.model_manager import ModelInstance -from dify_graph.model_runtime.entities.llm_entities import ( - LLMResult, - LLMResultChunk, - LLMResultChunkDelta, - LLMResultWithStructuredOutput, - LLMUsage, -) -from dify_graph.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - SystemPromptMessage, - TextPromptMessageContent, - UserPromptMessage, -) -from 
dify_graph.model_runtime.entities.model_entities import AIModelEntity, ParameterRule, ParameterType class TestStructuredOutput: diff --git a/api/tests/unit_tests/core/llm_generator/test_llm_generator.py b/api/tests/unit_tests/core/llm_generator/test_llm_generator.py index 5b7640696f..62e714deb6 100644 --- a/api/tests/unit_tests/core/llm_generator/test_llm_generator.py +++ b/api/tests/unit_tests/core/llm_generator/test_llm_generator.py @@ -2,18 +2,18 @@ import json from unittest.mock import MagicMock, patch import pytest +from graphon.model_runtime.entities.llm_entities import LLMMode, LLMResult +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.app.app_config.entities import ModelConfig from core.llm_generator.entities import RuleCodeGeneratePayload, RuleGeneratePayload, RuleStructuredOutputPayload from core.llm_generator.llm_generator import LLMGenerator -from dify_graph.model_runtime.entities.llm_entities import LLMMode, LLMResult -from dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError class TestLLMGenerator: @pytest.fixture def mock_model_instance(self): - with patch("core.llm_generator.llm_generator.ModelManager") as mock_manager: + with patch("core.llm_generator.llm_generator.ModelManager.for_tenant") as mock_manager: instance = MagicMock() mock_manager.return_value.get_default_model_instance.return_value = instance mock_manager.return_value.get_model_instance.return_value = instance @@ -98,7 +98,7 @@ class TestLLMGenerator: assert questions[0] == "Question 1?" 
def test_generate_suggested_questions_after_answer_auth_error(self, mock_model_instance): - with patch("core.llm_generator.llm_generator.ModelManager") as mock_manager: + with patch("core.llm_generator.llm_generator.ModelManager.for_tenant") as mock_manager: mock_manager.return_value.get_default_model_instance.side_effect = InvokeAuthorizationError("Auth failed") questions = LLMGenerator.generate_suggested_questions_after_answer("tenant_id", "histories") assert questions == [] @@ -314,8 +314,8 @@ class TestLLMGenerator: assert "An unexpected error occurred" in result["error"] def test_instruction_modify_legacy_no_last_run(self, mock_model_instance, model_config_entity): - with patch("extensions.ext_database.db.session.query") as mock_query: - mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None + with patch("extensions.ext_database.db.session.scalar") as mock_scalar: + mock_scalar.return_value = None # Mock __instruction_modify_common call via invoke_llm mock_response = MagicMock() @@ -328,12 +328,12 @@ class TestLLMGenerator: assert result == {"modified": "prompt"} def test_instruction_modify_legacy_with_last_run(self, mock_model_instance, model_config_entity): - with patch("extensions.ext_database.db.session.query") as mock_query: + with patch("extensions.ext_database.db.session.scalar") as mock_scalar: last_run = MagicMock() last_run.query = "q" last_run.answer = "a" last_run.error = "e" - mock_query.return_value.where.return_value.order_by.return_value.first.return_value = last_run + mock_scalar.return_value = last_run mock_response = MagicMock() mock_response.message.get_text_content.return_value = '{"modified": "prompt"}' @@ -483,8 +483,8 @@ class TestLLMGenerator: def test_instruction_modify_common_placeholders(self, mock_model_instance, model_config_entity): # Testing placeholders replacement via instruction_modify_legacy for convenience - with patch("extensions.ext_database.db.session.query") as mock_query: - 
mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None + with patch("extensions.ext_database.db.session.scalar") as mock_scalar: + mock_scalar.return_value = None mock_response = MagicMock() mock_response.message.get_text_content.return_value = '{"ok": true}' @@ -504,8 +504,8 @@ class TestLLMGenerator: assert "current_val" in user_msg_dict["instruction"] def test_instruction_modify_common_no_braces(self, mock_model_instance, model_config_entity): - with patch("extensions.ext_database.db.session.query") as mock_query: - mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None + with patch("extensions.ext_database.db.session.scalar") as mock_scalar: + mock_scalar.return_value = None mock_response = MagicMock() mock_response.message.get_text_content.return_value = "No braces here" mock_model_instance.invoke_llm.return_value = mock_response @@ -516,8 +516,8 @@ class TestLLMGenerator: assert "Could not find a valid JSON object" in result["error"] def test_instruction_modify_common_not_dict(self, mock_model_instance, model_config_entity): - with patch("extensions.ext_database.db.session.query") as mock_query: - mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None + with patch("extensions.ext_database.db.session.scalar") as mock_scalar: + mock_scalar.return_value = None mock_response = MagicMock() mock_response.message.get_text_content.return_value = "[1, 2, 3]" mock_model_instance.invoke_llm.return_value = mock_response @@ -528,7 +528,7 @@ class TestLLMGenerator: assert "An unexpected error occurred" in result["error"] def test_instruction_modify_common_other_node_type(self, mock_model_instance, model_config_entity): - with patch("core.llm_generator.llm_generator.ModelManager") as mock_manager: + with patch("core.llm_generator.llm_generator.ModelManager.for_tenant") as mock_manager: instance = MagicMock() mock_manager.return_value.get_model_instance.return_value = 
instance mock_response = MagicMock() @@ -556,8 +556,8 @@ class TestLLMGenerator: ) def test_instruction_modify_common_invoke_error(self, mock_model_instance, model_config_entity): - with patch("extensions.ext_database.db.session.query") as mock_query: - mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None + with patch("extensions.ext_database.db.session.scalar") as mock_scalar: + mock_scalar.return_value = None mock_model_instance.invoke_llm.side_effect = InvokeError("Invoke Failed") result = LLMGenerator.instruction_modify_legacy( @@ -566,8 +566,8 @@ class TestLLMGenerator: assert "Failed to generate code" in result["error"] def test_instruction_modify_common_exception(self, mock_model_instance, model_config_entity): - with patch("extensions.ext_database.db.session.query") as mock_query: - mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None + with patch("extensions.ext_database.db.session.scalar") as mock_scalar: + mock_scalar.return_value = None mock_model_instance.invoke_llm.side_effect = Exception("Random error") result = LLMGenerator.instruction_modify_legacy( @@ -576,8 +576,8 @@ class TestLLMGenerator: assert "An unexpected error occurred" in result["error"] def test_instruction_modify_common_json_error(self, mock_model_instance, model_config_entity): - with patch("extensions.ext_database.db.session.query") as mock_query: - mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None + with patch("extensions.ext_database.db.session.scalar") as mock_scalar: + mock_scalar.return_value = None mock_response = MagicMock() mock_response.message.get_text_content.return_value = "No JSON here" diff --git a/api/tests/unit_tests/core/mcp/server/test_streamable_http.py b/api/tests/unit_tests/core/mcp/server/test_streamable_http.py index f982765b1a..313d18c695 100644 --- a/api/tests/unit_tests/core/mcp/server/test_streamable_http.py +++ 
b/api/tests/unit_tests/core/mcp/server/test_streamable_http.py @@ -3,6 +3,7 @@ from unittest.mock import Mock, patch import jsonschema import pytest +from graphon.variables.input_entities import VariableEntity, VariableEntityType from core.app.features.rate_limiting.rate_limit import RateLimitGenerator from core.mcp import types @@ -18,7 +19,6 @@ from core.mcp.server.streamable_http import ( prepare_tool_arguments, process_mapping_response, ) -from dify_graph.variables.input_entities import VariableEntity, VariableEntityType from models.model import App, AppMCPServer, AppMode, EndUser diff --git a/api/tests/unit_tests/core/memory/test_token_buffer_memory.py b/api/tests/unit_tests/core/memory/test_token_buffer_memory.py index 5ecfe01808..9a5fb319d7 100644 --- a/api/tests/unit_tests/core/memory/test_token_buffer_memory.py +++ b/api/tests/unit_tests/core/memory/test_token_buffer_memory.py @@ -4,15 +4,15 @@ from unittest.mock import MagicMock, patch from uuid import uuid4 import pytest - -from core.memory.token_buffer_memory import TokenBufferMemory -from dify_graph.model_runtime.entities import ( +from graphon.model_runtime.entities import ( AssistantPromptMessage, ImagePromptMessageContent, PromptMessageRole, TextPromptMessageContent, UserPromptMessage, ) + +from core.memory.token_buffer_memory import TokenBufferMemory from models.model import AppMode # --------------------------------------------------------------------------- diff --git a/api/tests/unit_tests/core/model_runtime/test_model_provider_factory.py b/api/tests/unit_tests/core/model_runtime/test_model_provider_factory.py new file mode 100644 index 0000000000..6a672fdfd5 --- /dev/null +++ b/api/tests/unit_tests/core/model_runtime/test_model_provider_factory.py @@ -0,0 +1,419 @@ +from unittest.mock import Mock + +import pytest +from graphon.model_runtime.entities.common_entities import I18nObject +from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType +from 
graphon.model_runtime.entities.provider_entities import ( + ConfigurateMethod, + CredentialFormSchema, + FieldModelSchema, + FormType, + ModelCredentialSchema, + ProviderCredentialSchema, + ProviderEntity, +) +from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +from graphon.model_runtime.model_providers.__base.moderation_model import ModerationModel +from graphon.model_runtime.model_providers.__base.rerank_model import RerankModel +from graphon.model_runtime.model_providers.__base.speech2text_model import Speech2TextModel +from graphon.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel +from graphon.model_runtime.model_providers.__base.tts_model import TTSModel +from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory + + +def _build_model(model: str, model_type: ModelType) -> AIModelEntity: + return AIModelEntity( + model=model, + label=I18nObject(en_US=model), + model_type=model_type, + fetch_from=FetchFrom.PREDEFINED_MODEL, + model_properties={}, + ) + + +def _build_provider( + *, + provider: str, + provider_name: str, + supported_model_types: list[ModelType], + models: list[AIModelEntity] | None = None, + provider_credential_schema: ProviderCredentialSchema | None = None, + model_credential_schema: ModelCredentialSchema | None = None, +) -> ProviderEntity: + return ProviderEntity( + provider=provider, + provider_name=provider_name, + label=I18nObject(en_US=provider_name or provider), + supported_model_types=supported_model_types, + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + models=models or [], + provider_credential_schema=provider_credential_schema, + model_credential_schema=model_credential_schema, + ) + + +class _FakeModelRuntime: + def __init__(self, providers: list[ProviderEntity]) -> None: + self._providers = providers + self.validate_provider_credentials = Mock() + self.validate_model_credentials = Mock() + 
self.get_model_schema = Mock() + self.get_provider_icon = Mock() + + def fetch_model_providers(self) -> list[ProviderEntity]: + return self._providers + + +def test_model_provider_factory_resolves_runtime_provider_name() -> None: + provider = ProviderEntity( + provider="langgenius/openai/openai", + provider_name="openai", + label=I18nObject(en_US="OpenAI"), + supported_model_types=[ModelType.LLM], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + ) + factory = ModelProviderFactory(model_runtime=_FakeModelRuntime([provider])) + + provider_schema = factory.get_model_provider("openai") + + assert provider_schema.provider == "langgenius/openai/openai" + assert provider_schema.provider_name == "openai" + + +def test_model_provider_factory_resolves_canonical_short_name_independent_of_provider_order() -> None: + providers = [ + ProviderEntity( + provider="acme/openai/openai", + provider_name="", + label=I18nObject(en_US="Acme OpenAI"), + supported_model_types=[ModelType.LLM], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + ), + ProviderEntity( + provider="langgenius/openai/openai", + provider_name="openai", + label=I18nObject(en_US="OpenAI"), + supported_model_types=[ModelType.LLM], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + ), + ] + factory = ModelProviderFactory(model_runtime=_FakeModelRuntime(providers)) + + provider_schema = factory.get_model_provider("openai") + + assert provider_schema.provider == "langgenius/openai/openai" + assert provider_schema.provider_name == "openai" + + +def test_model_provider_factory_requires_runtime() -> None: + with pytest.raises(ValueError, match="model_runtime is required"): + ModelProviderFactory(model_runtime=None) # type: ignore[arg-type] + + +def test_model_provider_factory_get_providers_returns_runtime_providers() -> None: + providers = [ + _build_provider( + provider="langgenius/openai/openai", + provider_name="openai", + supported_model_types=[ModelType.LLM], + ) + ] + factory = 
ModelProviderFactory(model_runtime=_FakeModelRuntime(providers)) + + result = factory.get_providers() + + assert list(result) == providers + assert result is not providers + + +def test_model_provider_factory_get_provider_schema_delegates_to_provider_lookup() -> None: + provider = _build_provider( + provider="langgenius/openai/openai", + provider_name="openai", + supported_model_types=[ModelType.LLM], + ) + factory = ModelProviderFactory(model_runtime=_FakeModelRuntime([provider])) + + result = factory.get_provider_schema("openai") + + assert result is provider + + +def test_model_provider_factory_raises_for_unknown_provider() -> None: + factory = ModelProviderFactory( + model_runtime=_FakeModelRuntime( + [ + _build_provider( + provider="langgenius/openai/openai", + provider_name="openai", + supported_model_types=[ModelType.LLM], + ) + ] + ) + ) + + with pytest.raises(ValueError, match="Invalid provider: anthropic"): + factory.get_model_provider("anthropic") + + +def test_model_provider_factory_get_models_filters_provider_and_model_type() -> None: + providers = [ + _build_provider( + provider="langgenius/openai/openai", + provider_name="openai", + supported_model_types=[ModelType.LLM, ModelType.TTS], + models=[_build_model("gpt-4o-mini", ModelType.LLM), _build_model("tts-1", ModelType.TTS)], + ), + _build_provider( + provider="langgenius/cohere/cohere", + provider_name="cohere", + supported_model_types=[ModelType.RERANK], + models=[_build_model("rerank-v3", ModelType.RERANK)], + ), + ] + factory = ModelProviderFactory(model_runtime=_FakeModelRuntime(providers)) + + results = factory.get_models(provider="openai", model_type=ModelType.LLM) + + assert len(results) == 1 + assert results[0].provider == "langgenius/openai/openai" + assert [model.model for model in results[0].models] == ["gpt-4o-mini"] + + +def test_model_provider_factory_get_models_skips_providers_without_requested_model_type() -> None: + providers = [ + _build_provider( + 
provider="langgenius/openai/openai", + provider_name="openai", + supported_model_types=[ModelType.LLM], + models=[_build_model("gpt-4o-mini", ModelType.LLM)], + ), + _build_provider( + provider="langgenius/elevenlabs/elevenlabs", + provider_name="elevenlabs", + supported_model_types=[ModelType.TTS], + models=[_build_model("eleven_multilingual_v2", ModelType.TTS)], + ), + ] + factory = ModelProviderFactory(model_runtime=_FakeModelRuntime(providers)) + + results = factory.get_models(model_type=ModelType.TTS) + + assert len(results) == 1 + assert results[0].provider == "langgenius/elevenlabs/elevenlabs" + assert [model.model for model in results[0].models] == ["eleven_multilingual_v2"] + + +def test_model_provider_factory_get_models_without_model_type_keeps_all_provider_models() -> None: + providers = [ + _build_provider( + provider="langgenius/openai/openai", + provider_name="openai", + supported_model_types=[ModelType.LLM, ModelType.TTS], + models=[_build_model("gpt-4o-mini", ModelType.LLM), _build_model("tts-1", ModelType.TTS)], + ) + ] + factory = ModelProviderFactory(model_runtime=_FakeModelRuntime(providers)) + + results = factory.get_models(provider="openai") + + assert len(results) == 1 + assert [model.model for model in results[0].models] == ["gpt-4o-mini", "tts-1"] + + +def test_model_provider_factory_validates_provider_credentials() -> None: + runtime = _FakeModelRuntime( + [ + _build_provider( + provider="langgenius/openai/openai", + provider_name="openai", + supported_model_types=[ModelType.LLM], + provider_credential_schema=ProviderCredentialSchema( + credential_form_schemas=[ + CredentialFormSchema( + variable="api_key", + label=I18nObject(en_US="API key"), + type=FormType.SECRET_INPUT, + required=True, + ) + ] + ), + ) + ] + ) + factory = ModelProviderFactory(model_runtime=runtime) + + filtered = factory.provider_credentials_validate( + provider="openai", + credentials={"api_key": "secret", "ignored": "value"}, + ) + + assert filtered == {"api_key": 
"secret"} + runtime.validate_provider_credentials.assert_called_once_with( + provider="langgenius/openai/openai", + credentials={"api_key": "secret"}, + ) + + +def test_model_provider_factory_provider_credentials_validate_requires_schema() -> None: + factory = ModelProviderFactory( + model_runtime=_FakeModelRuntime( + [ + _build_provider( + provider="langgenius/openai/openai", + provider_name="openai", + supported_model_types=[ModelType.LLM], + ) + ] + ) + ) + + with pytest.raises(ValueError, match="Provider openai does not have provider_credential_schema"): + factory.provider_credentials_validate(provider="openai", credentials={"api_key": "secret"}) + + +def test_model_provider_factory_validates_model_credentials() -> None: + runtime = _FakeModelRuntime( + [ + _build_provider( + provider="langgenius/openai/openai", + provider_name="openai", + supported_model_types=[ModelType.LLM], + model_credential_schema=ModelCredentialSchema( + model=FieldModelSchema(label=I18nObject(en_US="Model")), + credential_form_schemas=[ + CredentialFormSchema( + variable="api_key", + label=I18nObject(en_US="API key"), + type=FormType.SECRET_INPUT, + required=True, + ) + ], + ), + ) + ] + ) + factory = ModelProviderFactory(model_runtime=runtime) + + filtered = factory.model_credentials_validate( + provider="openai", + model_type=ModelType.LLM, + model="gpt-4o-mini", + credentials={"api_key": "secret", "ignored": "value"}, + ) + + assert filtered == {"api_key": "secret"} + runtime.validate_model_credentials.assert_called_once_with( + provider="langgenius/openai/openai", + model_type=ModelType.LLM, + model="gpt-4o-mini", + credentials={"api_key": "secret"}, + ) + + +def test_model_provider_factory_model_credentials_validate_requires_schema() -> None: + factory = ModelProviderFactory( + model_runtime=_FakeModelRuntime( + [ + _build_provider( + provider="langgenius/openai/openai", + provider_name="openai", + supported_model_types=[ModelType.LLM], + ) + ] + ) + ) + + with 
pytest.raises(ValueError, match="Provider openai does not have model_credential_schema"): + factory.model_credentials_validate( + provider="openai", + model_type=ModelType.LLM, + model="gpt-4o-mini", + credentials={"api_key": "secret"}, + ) + + +def test_model_provider_factory_get_model_schema_and_icon_use_canonical_provider() -> None: + runtime = _FakeModelRuntime( + [ + _build_provider( + provider="langgenius/openai/openai", + provider_name="openai", + supported_model_types=[ModelType.LLM], + ) + ] + ) + runtime.get_model_schema.return_value = "schema" + runtime.get_provider_icon.return_value = (b"icon", "image/png") + factory = ModelProviderFactory(model_runtime=runtime) + + assert ( + factory.get_model_schema( + provider="openai", + model_type=ModelType.LLM, + model="gpt-4o-mini", + credentials=None, + ) + == "schema" + ) + assert factory.get_provider_icon("openai", "icon_small", "en_US") == (b"icon", "image/png") + runtime.get_model_schema.assert_called_once_with( + provider="langgenius/openai/openai", + model_type=ModelType.LLM, + model="gpt-4o-mini", + credentials={}, + ) + runtime.get_provider_icon.assert_called_once_with( + provider="langgenius/openai/openai", + icon_type="icon_small", + lang="en_US", + ) + + +@pytest.mark.parametrize( + ("model_type", "expected_type"), + [ + (ModelType.LLM, LargeLanguageModel), + (ModelType.TEXT_EMBEDDING, TextEmbeddingModel), + (ModelType.RERANK, RerankModel), + (ModelType.SPEECH2TEXT, Speech2TextModel), + (ModelType.MODERATION, ModerationModel), + (ModelType.TTS, TTSModel), + ], +) +def test_model_provider_factory_builds_model_type_instances( + model_type: ModelType, + expected_type: type[object], +) -> None: + factory = ModelProviderFactory( + model_runtime=_FakeModelRuntime( + [ + _build_provider( + provider="langgenius/openai/openai", + provider_name="openai", + supported_model_types=[model_type], + ) + ] + ) + ) + + instance = factory.get_model_type_instance("openai", model_type) + + assert isinstance(instance, 
expected_type) + + +def test_model_provider_factory_rejects_unsupported_model_type() -> None: + factory = ModelProviderFactory( + model_runtime=_FakeModelRuntime( + [ + _build_provider( + provider="langgenius/openai/openai", + provider_name="openai", + supported_model_types=[ModelType.LLM], + ) + ] + ) + ) + + with pytest.raises(ValueError, match="Unsupported model type: unsupported"): + factory.get_model_type_instance("openai", "unsupported") # type: ignore[arg-type] diff --git a/api/tests/unit_tests/core/moderation/test_content_moderation.py b/api/tests/unit_tests/core/moderation/test_content_moderation.py index e61cde22e7..3a97ad5c5d 100644 --- a/api/tests/unit_tests/core/moderation/test_content_moderation.py +++ b/api/tests/unit_tests/core/moderation/test_content_moderation.py @@ -324,7 +324,7 @@ class TestOpenAIModeration: with pytest.raises(ValueError, match="At least one of inputs_config or outputs_config must be enabled"): OpenAIModeration.validate_config("test-tenant", config) - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager.for_tenant", autospec=True) def test_moderation_for_inputs_no_violation(self, mock_model_manager: Mock, openai_moderation: OpenAIModeration): """Test input moderation when OpenAI API returns no violations.""" # Mock the model manager and instance @@ -341,7 +341,7 @@ class TestOpenAIModeration: assert result.action == ModerationAction.DIRECT_OUTPUT assert result.preset_response == "Content flagged by OpenAI moderation." 
- @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager.for_tenant", autospec=True) def test_moderation_for_inputs_with_violation(self, mock_model_manager: Mock, openai_moderation: OpenAIModeration): """Test input moderation when OpenAI API detects violations.""" # Mock the model manager to return violation @@ -358,7 +358,7 @@ class TestOpenAIModeration: assert result.action == ModerationAction.DIRECT_OUTPUT assert result.preset_response == "Content flagged by OpenAI moderation." - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager.for_tenant", autospec=True) def test_moderation_for_inputs_query_included(self, mock_model_manager: Mock, openai_moderation: OpenAIModeration): """Test that query is included in moderation check with special key.""" mock_instance = MagicMock() @@ -385,7 +385,7 @@ class TestOpenAIModeration: assert "u" in moderated_text assert "e" in moderated_text - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager.for_tenant", autospec=True) def test_moderation_for_inputs_disabled(self, mock_model_manager: Mock): """Test input moderation when inputs_config is disabled.""" config = { @@ -400,7 +400,7 @@ class TestOpenAIModeration: # Should not call the API when disabled mock_model_manager.assert_not_called() - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager.for_tenant", autospec=True) def test_moderation_for_outputs_no_violation(self, mock_model_manager: Mock, openai_moderation: OpenAIModeration): """Test output moderation when OpenAI API returns no violations.""" mock_instance = MagicMock() @@ -414,7 
+414,7 @@ class TestOpenAIModeration: assert result.action == ModerationAction.DIRECT_OUTPUT assert result.preset_response == "Response blocked by moderation." - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager.for_tenant", autospec=True) def test_moderation_for_outputs_with_violation(self, mock_model_manager: Mock, openai_moderation: OpenAIModeration): """Test output moderation when OpenAI API detects violations.""" mock_instance = MagicMock() @@ -427,7 +427,7 @@ class TestOpenAIModeration: assert result.flagged is True assert result.action == ModerationAction.DIRECT_OUTPUT - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager.for_tenant", autospec=True) def test_moderation_for_outputs_disabled(self, mock_model_manager: Mock): """Test output moderation when outputs_config is disabled.""" config = { @@ -441,7 +441,7 @@ class TestOpenAIModeration: assert result.flagged is False mock_model_manager.assert_not_called() - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager.for_tenant", autospec=True) def test_model_manager_called_with_correct_params( self, mock_model_manager: Mock, openai_moderation: OpenAIModeration ): @@ -629,7 +629,7 @@ class TestPresetManagement: assert result.flagged is True assert result.preset_response == "Custom output blocked message" - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager.for_tenant", autospec=True) def test_openai_preset_response_in_inputs(self, mock_model_manager: Mock): """Test preset response is properly returned for OpenAI input violations.""" mock_instance = MagicMock() @@ 
-650,7 +650,7 @@ class TestPresetManagement: assert result.flagged is True assert result.preset_response == "OpenAI input blocked" - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager.for_tenant", autospec=True) def test_openai_preset_response_in_outputs(self, mock_model_manager: Mock): """Test preset response is properly returned for OpenAI output violations.""" mock_instance = MagicMock() @@ -989,7 +989,7 @@ class TestOpenAIModerationAdvanced: - Performance considerations """ - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager.for_tenant", autospec=True) def test_openai_api_timeout_handling(self, mock_model_manager: Mock): """ Test graceful handling of OpenAI API timeouts. @@ -1012,7 +1012,7 @@ class TestOpenAIModerationAdvanced: with pytest.raises(TimeoutError): moderation.moderation_for_inputs({"text": "test"}, "") - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager.for_tenant", autospec=True) def test_openai_api_rate_limit_handling(self, mock_model_manager: Mock): """ Test handling of OpenAI API rate limit errors. @@ -1035,7 +1035,7 @@ class TestOpenAIModerationAdvanced: with pytest.raises(Exception, match="Rate limit exceeded"): moderation.moderation_for_inputs({"text": "test"}, "") - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager.for_tenant", autospec=True) def test_openai_with_multiple_input_fields(self, mock_model_manager: Mock): """ Test OpenAI moderation with multiple input fields. 
@@ -1079,7 +1079,7 @@ class TestOpenAIModerationAdvanced: assert "u" in moderated_text assert "e" in moderated_text - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager.for_tenant", autospec=True) def test_openai_empty_text_handling(self, mock_model_manager: Mock): """ Test OpenAI moderation with empty text inputs. @@ -1103,7 +1103,7 @@ class TestOpenAIModerationAdvanced: assert result.flagged is False mock_instance.invoke_moderation.assert_called_once() - @patch("core.moderation.openai_moderation.openai_moderation.ModelManager", autospec=True) + @patch("core.moderation.openai_moderation.openai_moderation.ModelManager.for_tenant", autospec=True) def test_openai_model_instance_fetched_on_each_call(self, mock_model_manager: Mock): """ Test that ModelManager fetches a fresh model instance on each call. diff --git a/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace.py b/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace.py index dfd61acfa7..62d631a754 100644 --- a/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace.py +++ b/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace.py @@ -5,6 +5,8 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest +from graphon.entities import WorkflowNodeExecution +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from opentelemetry.trace import Link, SpanContext, SpanKind, Status, StatusCode, TraceFlags import core.ops.aliyun_trace.aliyun_trace as aliyun_trace_module @@ -34,8 +36,6 @@ from core.ops.entities.trace_entity import ( ToolTraceInfo, WorkflowTraceInfo, ) -from dify_graph.entities import WorkflowNodeExecution -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey class RecordingTraceClient: @@ -396,14 +396,14 @@ def test_get_workflow_node_executions_builds_repo_and_fetches( 
monkeypatch.setattr(aliyun_trace_module, "db", SimpleNamespace(engine="engine")) repo = MagicMock() - repo.get_by_workflow_run.return_value = ["node1"] + repo.get_by_workflow_execution.return_value = ["node1"] mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo monkeypatch.setattr(aliyun_trace_module, "DifyCoreRepositoryFactory", mock_factory) result = trace_instance.get_workflow_node_executions(trace_info) assert result == ["node1"] - repo.get_by_workflow_run.assert_called_once_with(workflow_run_id=trace_info.workflow_run_id) + repo.get_by_workflow_execution.assert_called_once_with(workflow_execution_id=trace_info.workflow_run_id) def test_build_workflow_node_span_routes_llm_type(trace_instance: AliyunDataTrace, monkeypatch: pytest.MonkeyPatch): diff --git a/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace_utils.py b/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace_utils.py index 763fc90710..2d2be12f05 100644 --- a/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace_utils.py +++ b/api/tests/unit_tests/core/ops/aliyun_trace/test_aliyun_trace_utils.py @@ -1,6 +1,8 @@ import json from unittest.mock import MagicMock +from graphon.entities import WorkflowNodeExecution +from graphon.enums import WorkflowNodeExecutionStatus from opentelemetry.trace import Link, StatusCode from core.ops.aliyun_trace.entities.semconv import ( @@ -24,8 +26,6 @@ from core.ops.aliyun_trace.utils import ( serialize_json_data, ) from core.rag.models.document import Document -from dify_graph.entities import WorkflowNodeExecution -from dify_graph.enums import WorkflowNodeExecutionStatus from models import EndUser @@ -45,11 +45,8 @@ def test_get_user_id_from_message_data_with_end_user(monkeypatch): end_user_data = MagicMock(spec=EndUser) end_user_data.session_id = "session_id" - mock_query = MagicMock() - mock_query.where.return_value.first.return_value = end_user_data - mock_session = MagicMock() - 
mock_session.query.return_value = mock_query + mock_session.get.return_value = end_user_data from core.ops.aliyun_trace.utils import db @@ -63,11 +60,8 @@ def test_get_user_id_from_message_data_end_user_not_found(monkeypatch): message_data.from_account_id = "account_id" message_data.from_end_user_id = "end_user_id" - mock_query = MagicMock() - mock_query.where.return_value.first.return_value = None - mock_session = MagicMock() - mock_session.query.return_value = mock_query + mock_session.get.return_value = None from core.ops.aliyun_trace.utils import db diff --git a/api/tests/unit_tests/core/ops/arize_phoenix_trace/test_arize_phoenix_trace.py b/api/tests/unit_tests/core/ops/arize_phoenix_trace/test_arize_phoenix_trace.py index 1cee2f5b68..4ce9e22fd7 100644 --- a/api/tests/unit_tests/core/ops/arize_phoenix_trace/test_arize_phoenix_trace.py +++ b/api/tests/unit_tests/core/ops/arize_phoenix_trace/test_arize_phoenix_trace.py @@ -254,7 +254,7 @@ def test_workflow_trace_full(mock_db, mock_repo_factory, mock_sessionmaker, trac node1.id = "n1" node1.error = None - repo.get_by_workflow_run.return_value = [node1] + repo.get_by_workflow_execution.return_value = [node1] with patch.object(trace_instance, "get_service_account_with_tenant"): trace_instance.workflow_trace(info) diff --git a/api/tests/unit_tests/core/ops/langfuse_trace/test_langfuse_trace.py b/api/tests/unit_tests/core/ops/langfuse_trace/test_langfuse_trace.py index 0ff135562c..97f7a16327 100644 --- a/api/tests/unit_tests/core/ops/langfuse_trace/test_langfuse_trace.py +++ b/api/tests/unit_tests/core/ops/langfuse_trace/test_langfuse_trace.py @@ -5,6 +5,7 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest +from graphon.enums import BuiltinNodeTypes from core.ops.entities.config_entity import LangfuseConfig from core.ops.entities.trace_entity import ( @@ -25,7 +26,6 @@ from core.ops.langfuse_trace.entities.langfuse_trace_entity import ( UnitEnum, ) from 
core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace -from dify_graph.enums import BuiltinNodeTypes from models import EndUser from models.enums import MessageStatus @@ -174,7 +174,7 @@ def test_workflow_trace_with_message_id(trace_instance, monkeypatch): node_other.metadata = None repo = MagicMock() - repo.get_by_workflow_run.return_value = [node_llm, node_other] + repo.get_by_workflow_execution.return_value = [node_llm, node_other] mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo @@ -244,7 +244,7 @@ def test_workflow_trace_no_message_id(trace_instance, monkeypatch): monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.sessionmaker", lambda bind: lambda: MagicMock()) monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db", MagicMock(engine="engine")) repo = MagicMock() - repo.get_by_workflow_run.return_value = [] + repo.get_by_workflow_execution.return_value = [] mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.DifyCoreRepositoryFactory", mock_factory) @@ -365,9 +365,7 @@ def test_message_trace_with_end_user(trace_instance, monkeypatch): mock_end_user = MagicMock(spec=EndUser) mock_end_user.session_id = "session-id-123" - mock_query = MagicMock() - mock_query.where.return_value.first.return_value = mock_end_user - monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db.session.query", lambda model: mock_query) + monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.db.session.get", lambda model, pk: mock_end_user) trace_instance.add_trace = MagicMock() trace_instance.add_generation = MagicMock() @@ -680,7 +678,7 @@ def test_workflow_trace_handles_usage_extraction_error(trace_instance, monkeypat node.outputs = {} repo = MagicMock() - repo.get_by_workflow_run.return_value = [node] + repo.get_by_workflow_execution.return_value = [node] mock_factory = MagicMock() 
mock_factory.create_workflow_node_execution_repository.return_value = repo monkeypatch.setattr("core.ops.langfuse_trace.langfuse_trace.DifyCoreRepositoryFactory", mock_factory) diff --git a/api/tests/unit_tests/core/ops/langsmith_trace/test_langsmith_trace.py b/api/tests/unit_tests/core/ops/langsmith_trace/test_langsmith_trace.py index f656f7435f..bfe916f018 100644 --- a/api/tests/unit_tests/core/ops/langsmith_trace/test_langsmith_trace.py +++ b/api/tests/unit_tests/core/ops/langsmith_trace/test_langsmith_trace.py @@ -3,6 +3,7 @@ from datetime import datetime, timedelta from unittest.mock import MagicMock import pytest +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from core.ops.entities.config_entity import LangSmithConfig from core.ops.entities.trace_entity import ( @@ -21,7 +22,6 @@ from core.ops.langsmith_trace.entities.langsmith_trace_entity import ( LangSmithRunUpdateModel, ) from core.ops.langsmith_trace.langsmith_trace import LangSmithDataTrace -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from models import EndUser @@ -184,7 +184,7 @@ def test_workflow_trace(trace_instance, monkeypatch): node_retrieval.metadata = {} repo = MagicMock() - repo.get_by_workflow_run.return_value = [node_llm, node_other, node_retrieval] + repo.get_by_workflow_execution.return_value = [node_llm, node_other, node_retrieval] mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo @@ -255,7 +255,7 @@ def test_workflow_trace_no_start_time(trace_instance, monkeypatch): monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.sessionmaker", lambda bind: lambda: mock_session) monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db", MagicMock(engine="engine")) repo = MagicMock() - repo.get_by_workflow_run.return_value = [] + repo.get_by_workflow_execution.return_value = [] mock_factory = MagicMock() 
mock_factory.create_workflow_node_execution_repository.return_value = repo monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.DifyCoreRepositoryFactory", mock_factory) @@ -319,9 +319,7 @@ def test_message_trace(trace_instance, monkeypatch): # Mock EndUser lookup mock_end_user = MagicMock(spec=EndUser) mock_end_user.session_id = "session-id-123" - mock_query = MagicMock() - mock_query.where.return_value.first.return_value = mock_end_user - monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db.session.query", lambda model: mock_query) + monkeypatch.setattr("core.ops.langsmith_trace.langsmith_trace.db.session.get", lambda model, pk: mock_end_user) trace_instance.add_run = MagicMock() @@ -565,7 +563,7 @@ def test_workflow_trace_usage_extraction_error(trace_instance, monkeypatch, capl node_llm.metadata = {} repo = MagicMock() - repo.get_by_workflow_run.return_value = [node_llm] + repo.get_by_workflow_execution.return_value = [node_llm] mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo diff --git a/api/tests/unit_tests/core/ops/mlflow_trace/test_mlflow_trace.py b/api/tests/unit_tests/core/ops/mlflow_trace/test_mlflow_trace.py index cccedaa08c..f4c485a9fc 100644 --- a/api/tests/unit_tests/core/ops/mlflow_trace/test_mlflow_trace.py +++ b/api/tests/unit_tests/core/ops/mlflow_trace/test_mlflow_trace.py @@ -9,6 +9,7 @@ from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest +from graphon.enums import BuiltinNodeTypes from core.ops.entities.config_entity import DatabricksConfig, MLflowConfig from core.ops.entities.trace_entity import ( @@ -21,7 +22,6 @@ from core.ops.entities.trace_entity import ( WorkflowTraceInfo, ) from core.ops.mlflow_trace.mlflow_trace import MLflowDataTrace, datetime_to_nanoseconds -from dify_graph.enums import BuiltinNodeTypes # ── Helpers ────────────────────────────────────────────────────────────────── @@ -330,7 +330,7 @@ class 
TestTraceDispatcher: class TestWorkflowTrace: def test_basic_workflow_no_nodes(self, trace_instance, mock_tracing, mock_db): - mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = [] + mock_db.session.scalars.return_value.all.return_value = [] span = MagicMock() mock_tracing["start"].return_value = span mock_tracing["set"].return_value = "token" @@ -343,7 +343,7 @@ class TestWorkflowTrace: span.end.assert_called_once() def test_workflow_filters_sys_inputs_and_adds_query(self, trace_instance, mock_tracing, mock_db): - mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = [] + mock_db.session.scalars.return_value.all.return_value = [] span = MagicMock() mock_tracing["start"].return_value = span mock_tracing["set"].return_value = "token" @@ -374,7 +374,7 @@ class TestWorkflowTrace: ), outputs='{"text": "hello world"}', ) - mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = [llm_node] + mock_db.session.scalars.return_value.all.return_value = [llm_node] workflow_span = MagicMock() node_span = MagicMock() @@ -397,7 +397,7 @@ class TestWorkflowTrace: } ), ) - mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = [qc_node] + mock_db.session.scalars.return_value.all.return_value = [qc_node] workflow_span = MagicMock() node_span = MagicMock() mock_tracing["start"].side_effect = [workflow_span, node_span] @@ -411,7 +411,7 @@ class TestWorkflowTrace: node_type=BuiltinNodeTypes.HTTP_REQUEST, process_data='{"url": "https://api.com"}', ) - mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = [http_node] + mock_db.session.scalars.return_value.all.return_value = [http_node] workflow_span = MagicMock() node_span = MagicMock() mock_tracing["start"].side_effect = [workflow_span, node_span] @@ -434,7 +434,7 @@ class TestWorkflowTrace: } ), ) - 
mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = [kr_node] + mock_db.session.scalars.return_value.all.return_value = [kr_node] workflow_span = MagicMock() node_span = MagicMock() mock_tracing["start"].side_effect = [workflow_span, node_span] @@ -448,7 +448,7 @@ class TestWorkflowTrace: def test_workflow_with_failed_node(self, trace_instance, mock_tracing, mock_db): failed_node = _make_node(status="failed") - mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = [failed_node] + mock_db.session.scalars.return_value.all.return_value = [failed_node] workflow_span = MagicMock() node_span = MagicMock() mock_tracing["start"].side_effect = [workflow_span, node_span] @@ -459,7 +459,7 @@ class TestWorkflowTrace: node_span.add_event.assert_called_once() def test_workflow_with_workflow_error(self, trace_instance, mock_tracing, mock_db): - mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = [] + mock_db.session.scalars.return_value.all.return_value = [] workflow_span = MagicMock() mock_tracing["start"].return_value = workflow_span mock_tracing["set"].return_value = "token" @@ -473,7 +473,7 @@ class TestWorkflowTrace: def test_workflow_node_no_inputs_no_outputs(self, trace_instance, mock_tracing, mock_db): node = _make_node(inputs=None, outputs=None) - mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = [node] + mock_db.session.scalars.return_value.all.return_value = [node] workflow_span = MagicMock() node_span = MagicMock() mock_tracing["start"].side_effect = [workflow_span, node_span] @@ -486,7 +486,7 @@ class TestWorkflowTrace: assert end_call.kwargs["outputs"] == {} def test_workflow_no_user_id_no_conversation_id(self, trace_instance, mock_tracing, mock_db): - mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = [] + 
mock_db.session.scalars.return_value.all.return_value = [] span = MagicMock() mock_tracing["start"].return_value = span mock_tracing["set"].return_value = "token" @@ -501,7 +501,7 @@ class TestWorkflowTrace: def test_workflow_empty_query(self, trace_instance, mock_tracing, mock_db): """When query is empty string, it's falsy so no query key added.""" - mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = [] + mock_db.session.scalars.return_value.all.return_value = [] span = MagicMock() mock_tracing["start"].return_value = span mock_tracing["set"].return_value = "token" @@ -680,12 +680,12 @@ class TestGetMessageUserId: def test_returns_end_user_session_id(self, trace_instance, mock_db): end_user = MagicMock() end_user.session_id = "session-1" - mock_db.session.query.return_value.where.return_value.first.return_value = end_user + mock_db.session.get.return_value = end_user result = trace_instance._get_message_user_id({"from_end_user_id": "eu-1"}) assert result == "session-1" def test_returns_account_id_when_no_end_user(self, trace_instance, mock_db): - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.get.return_value = None result = trace_instance._get_message_user_id({"from_end_user_id": "eu-1", "from_account_id": "acc-1"}) assert result == "acc-1" @@ -834,7 +834,7 @@ class TestGenerateNameTrace: class TestGetWorkflowNodes: def test_queries_db(self, trace_instance, mock_db): - mock_db.session.query.return_value.filter.return_value.order_by.return_value.all.return_value = ["n1", "n2"] + mock_db.session.scalars.return_value.all.return_value = ["n1", "n2"] result = trace_instance._get_workflow_nodes("run-1") assert result == ["n1", "n2"] diff --git a/api/tests/unit_tests/core/ops/opik_trace/test_opik_trace.py b/api/tests/unit_tests/core/ops/opik_trace/test_opik_trace.py index b2cb7d5109..1cb32f2ee0 100644 --- a/api/tests/unit_tests/core/ops/opik_trace/test_opik_trace.py +++ 
b/api/tests/unit_tests/core/ops/opik_trace/test_opik_trace.py @@ -5,6 +5,7 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from core.ops.entities.config_entity import OpikConfig from core.ops.entities.trace_entity import ( @@ -18,7 +19,6 @@ from core.ops.entities.trace_entity import ( WorkflowTraceInfo, ) from core.ops.opik_trace.opik_trace import OpikDataTrace, prepare_opik_uuid, wrap_dict, wrap_metadata -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from models import EndUser from models.enums import MessageStatus @@ -199,7 +199,7 @@ def test_workflow_trace_with_message_id(trace_instance, monkeypatch): node_other.metadata = {WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS.value: 10} repo = MagicMock() - repo.get_by_workflow_run.return_value = [node_llm, node_other] + repo.get_by_workflow_execution.return_value = [node_llm, node_other] mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo @@ -253,7 +253,7 @@ def test_workflow_trace_no_message_id(trace_instance, monkeypatch): monkeypatch.setattr("core.ops.opik_trace.opik_trace.sessionmaker", lambda bind: lambda: MagicMock()) monkeypatch.setattr("core.ops.opik_trace.opik_trace.db", MagicMock(engine="engine")) repo = MagicMock() - repo.get_by_workflow_run.return_value = [] + repo.get_by_workflow_execution.return_value = [] mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo monkeypatch.setattr("core.ops.opik_trace.opik_trace.DifyCoreRepositoryFactory", mock_factory) @@ -373,9 +373,7 @@ def test_message_trace_with_end_user(trace_instance, monkeypatch): mock_end_user = MagicMock(spec=EndUser) mock_end_user.session_id = "session-id-123" - mock_query = MagicMock() - mock_query.where.return_value.first.return_value = mock_end_user - 
monkeypatch.setattr("core.ops.opik_trace.opik_trace.db.session.query", lambda model: mock_query) + monkeypatch.setattr("core.ops.opik_trace.opik_trace.db.session.get", lambda model, pk: mock_end_user) trace_instance.add_trace = MagicMock(return_value=MagicMock(id="trace_id_2")) trace_instance.add_span = MagicMock() @@ -657,7 +655,7 @@ def test_workflow_trace_usage_extraction_error_fixed(trace_instance, monkeypatch node.outputs = {} repo = MagicMock() - repo.get_by_workflow_run.return_value = [node] + repo.get_by_workflow_execution.return_value = [node] mock_factory = MagicMock() mock_factory.create_workflow_node_execution_repository.return_value = repo monkeypatch.setattr("core.ops.opik_trace.opik_trace.DifyCoreRepositoryFactory", mock_factory) diff --git a/api/tests/unit_tests/core/ops/tencent_trace/test_span_builder.py b/api/tests/unit_tests/core/ops/tencent_trace/test_span_builder.py index a0b6d52720..696f859b6f 100644 --- a/api/tests/unit_tests/core/ops/tencent_trace/test_span_builder.py +++ b/api/tests/unit_tests/core/ops/tencent_trace/test_span_builder.py @@ -1,6 +1,8 @@ from datetime import datetime from unittest.mock import MagicMock, patch +from graphon.entities import WorkflowNodeExecution +from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus from opentelemetry.trace import StatusCode from core.ops.entities.trace_entity import ( @@ -25,8 +27,6 @@ from core.ops.tencent_trace.entities.semconv import ( ) from core.ops.tencent_trace.span_builder import TencentSpanBuilder from core.rag.models.document import Document -from dify_graph.entities import WorkflowNodeExecution -from dify_graph.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus class TestTencentSpanBuilder: diff --git a/api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace.py b/api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace.py index f259e4639f..382e5dadc3 100644 --- 
a/api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace.py +++ b/api/tests/unit_tests/core/ops/tencent_trace/test_tencent_trace.py @@ -2,6 +2,8 @@ import logging from unittest.mock import MagicMock, patch import pytest +from graphon.entities import WorkflowNodeExecution +from graphon.enums import BuiltinNodeTypes from core.ops.entities.config_entity import TencentConfig from core.ops.entities.trace_entity import ( @@ -14,8 +16,6 @@ from core.ops.entities.trace_entity import ( WorkflowTraceInfo, ) from core.ops.tencent_trace.tencent_trace import TencentDataTrace -from dify_graph.entities import WorkflowNodeExecution -from dify_graph.enums import BuiltinNodeTypes from models import Account, App, TenantAccountJoin logger = logging.getLogger(__name__) @@ -413,7 +413,7 @@ class TestTencentDataTrace: with patch( "core.ops.tencent_trace.tencent_trace.SQLAlchemyWorkflowNodeExecutionRepository" ) as mock_repo: - mock_repo.return_value.get_by_workflow_run.return_value = mock_executions + mock_repo.return_value.get_by_workflow_execution.return_value = mock_executions results = tencent_data_trace._get_workflow_node_executions(trace_info) diff --git a/api/tests/unit_tests/core/ops/test_arize_phoenix_trace.py b/api/tests/unit_tests/core/ops/test_arize_phoenix_trace.py index 49d6b698ef..6b5cb5b09a 100644 --- a/api/tests/unit_tests/core/ops/test_arize_phoenix_trace.py +++ b/api/tests/unit_tests/core/ops/test_arize_phoenix_trace.py @@ -1,7 +1,7 @@ +from graphon.enums import BUILT_IN_NODE_TYPES, BuiltinNodeTypes from openinference.semconv.trace import OpenInferenceSpanKindValues from core.ops.arize_phoenix_trace.arize_phoenix_trace import _NODE_TYPE_TO_SPAN_KIND, _get_node_span_kind -from dify_graph.enums import BUILT_IN_NODE_TYPES, BuiltinNodeTypes class TestGetNodeSpanKind: diff --git a/api/tests/unit_tests/core/ops/test_lookup_helpers.py b/api/tests/unit_tests/core/ops/test_lookup_helpers.py new file mode 100644 index 0000000000..86aa68643d --- /dev/null +++ 
b/api/tests/unit_tests/core/ops/test_lookup_helpers.py @@ -0,0 +1,554 @@ +"""Unit tests for lookup helper functions in core.ops.ops_trace_manager. + +Covers: +- _lookup_app_and_workspace_names +- _lookup_credential_name +- _lookup_llm_credential_info +- TraceTask._get_user_id_from_metadata +""" + +from unittest.mock import MagicMock, patch + +import pytest + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + + +def _make_db_and_session_patches(scalar_side_effect=None, scalar_return_value=None): + """Return (mock_db, cm, session) ready to patch 'core.ops.ops_trace_manager.db' + and 'core.ops.ops_trace_manager.Session'. + + Provide either scalar_side_effect (list, for multiple calls) or + scalar_return_value (single value). + """ + mock_db = MagicMock() + mock_db.engine = MagicMock() + + session = MagicMock() + if scalar_side_effect is not None: + session.scalar.side_effect = scalar_side_effect + else: + session.scalar.return_value = scalar_return_value + + cm = MagicMock() + cm.__enter__ = MagicMock(return_value=session) + cm.__exit__ = MagicMock(return_value=False) + + return mock_db, cm, session + + +# --------------------------------------------------------------------------- +# _lookup_app_and_workspace_names +# --------------------------------------------------------------------------- + + +class TestLookupAppAndWorkspaceNames: + """Tests for _lookup_app_and_workspace_names(app_id, tenant_id).""" + + def test_both_found(self): + """Returns (app_name, workspace_name) when both records exist.""" + from core.ops.ops_trace_manager import _lookup_app_and_workspace_names + + mock_db, cm, _session = _make_db_and_session_patches(scalar_side_effect=["MyApp", "MyWorkspace"]) + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", return_value=cm), + ): + app_name, workspace_name = 
_lookup_app_and_workspace_names("app-123", "tenant-456") + + assert app_name == "MyApp" + assert workspace_name == "MyWorkspace" + + def test_app_only_found(self): + """Returns (app_name, '') when tenant record is absent.""" + from core.ops.ops_trace_manager import _lookup_app_and_workspace_names + + mock_db, cm, _session = _make_db_and_session_patches(scalar_side_effect=["MyApp", None]) + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", return_value=cm), + ): + app_name, workspace_name = _lookup_app_and_workspace_names("app-123", "tenant-456") + + assert app_name == "MyApp" + assert workspace_name == "" + + def test_tenant_only_found(self): + """Returns ('', workspace_name) when app record is absent.""" + from core.ops.ops_trace_manager import _lookup_app_and_workspace_names + + mock_db, cm, _session = _make_db_and_session_patches(scalar_side_effect=[None, "MyWorkspace"]) + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", return_value=cm), + ): + app_name, workspace_name = _lookup_app_and_workspace_names("app-123", "tenant-456") + + assert app_name == "" + assert workspace_name == "MyWorkspace" + + def test_neither_found(self): + """Returns ('', '') when both DB lookups return None.""" + from core.ops.ops_trace_manager import _lookup_app_and_workspace_names + + mock_db, cm, _session = _make_db_and_session_patches(scalar_side_effect=[None, None]) + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", return_value=cm), + ): + app_name, workspace_name = _lookup_app_and_workspace_names("app-123", "tenant-456") + + assert app_name == "" + assert workspace_name == "" + + def test_none_inputs_skips_db(self): + """Returns ('', '') immediately when both IDs are None — no DB access.""" + from core.ops.ops_trace_manager import _lookup_app_and_workspace_names + + mock_db = MagicMock() + mock_session_cls = 
MagicMock() + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", mock_session_cls), + ): + app_name, workspace_name = _lookup_app_and_workspace_names(None, None) + + mock_session_cls.assert_not_called() + assert app_name == "" + assert workspace_name == "" + + def test_app_id_none_only_queries_tenant(self): + """When app_id is None, only the tenant query is issued.""" + from core.ops.ops_trace_manager import _lookup_app_and_workspace_names + + mock_db, cm, session = _make_db_and_session_patches(scalar_return_value="OnlyWorkspace") + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", return_value=cm), + ): + app_name, workspace_name = _lookup_app_and_workspace_names(None, "tenant-456") + + assert app_name == "" + assert workspace_name == "OnlyWorkspace" + assert session.scalar.call_count == 1 + + def test_tenant_id_none_only_queries_app(self): + """When tenant_id is None, only the app query is issued.""" + from core.ops.ops_trace_manager import _lookup_app_and_workspace_names + + mock_db, cm, session = _make_db_and_session_patches(scalar_return_value="OnlyApp") + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", return_value=cm), + ): + app_name, workspace_name = _lookup_app_and_workspace_names("app-123", None) + + assert app_name == "OnlyApp" + assert workspace_name == "" + assert session.scalar.call_count == 1 + + +# --------------------------------------------------------------------------- +# _lookup_credential_name +# --------------------------------------------------------------------------- + + +class TestLookupCredentialName: + """Tests for _lookup_credential_name(credential_id, provider_type).""" + + @pytest.mark.parametrize("provider_type", ["builtin", "plugin", "api", "workflow", "mcp"]) + def test_known_provider_types_return_name(self, provider_type): + """Each valid provider_type results in 
a DB query and returns the credential name.""" + from core.ops.ops_trace_manager import _lookup_credential_name + + mock_db, cm, session = _make_db_and_session_patches(scalar_return_value="CredentialA") + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", return_value=cm), + ): + result = _lookup_credential_name("cred-123", provider_type) + + assert result == "CredentialA" + session.scalar.assert_called_once() + + def test_credential_not_found_returns_empty_string(self): + """Returns '' when DB yields None for the given credential_id.""" + from core.ops.ops_trace_manager import _lookup_credential_name + + mock_db, cm, _session = _make_db_and_session_patches(scalar_return_value=None) + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", return_value=cm), + ): + result = _lookup_credential_name("cred-999", "api") + + assert result == "" + + def test_invalid_provider_type_returns_empty_string_without_db(self): + """Returns '' immediately for an unrecognised provider_type — no DB access.""" + from core.ops.ops_trace_manager import _lookup_credential_name + + mock_db = MagicMock() + mock_session_cls = MagicMock() + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", mock_session_cls), + ): + result = _lookup_credential_name("cred-123", "unknown_type") + + mock_session_cls.assert_not_called() + assert result == "" + + def test_none_credential_id_returns_empty_string_without_db(self): + """Returns '' immediately when credential_id is None — no DB access.""" + from core.ops.ops_trace_manager import _lookup_credential_name + + mock_db = MagicMock() + mock_session_cls = MagicMock() + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", mock_session_cls), + ): + result = _lookup_credential_name(None, "api") + + mock_session_cls.assert_not_called() + assert result 
== "" + + def test_none_provider_type_returns_empty_string_without_db(self): + """Returns '' immediately when provider_type is None — no DB access.""" + from core.ops.ops_trace_manager import _lookup_credential_name + + mock_db = MagicMock() + mock_session_cls = MagicMock() + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", mock_session_cls), + ): + result = _lookup_credential_name("cred-123", None) + + mock_session_cls.assert_not_called() + assert result == "" + + def test_builtin_and_plugin_map_to_same_model(self): + """Both 'builtin' and 'plugin' provider_types query BuiltinToolProvider.""" + from core.ops.ops_trace_manager import _PROVIDER_TYPE_TO_MODEL + from models.tools import BuiltinToolProvider + + assert _PROVIDER_TYPE_TO_MODEL["builtin"] is BuiltinToolProvider + assert _PROVIDER_TYPE_TO_MODEL["plugin"] is BuiltinToolProvider + + def test_api_maps_to_api_tool_provider(self): + """'api' maps to ApiToolProvider.""" + from core.ops.ops_trace_manager import _PROVIDER_TYPE_TO_MODEL + from models.tools import ApiToolProvider + + assert _PROVIDER_TYPE_TO_MODEL["api"] is ApiToolProvider + + def test_workflow_maps_to_workflow_tool_provider(self): + """'workflow' maps to WorkflowToolProvider.""" + from core.ops.ops_trace_manager import _PROVIDER_TYPE_TO_MODEL + from models.tools import WorkflowToolProvider + + assert _PROVIDER_TYPE_TO_MODEL["workflow"] is WorkflowToolProvider + + def test_mcp_maps_to_mcp_tool_provider(self): + """'mcp' maps to MCPToolProvider.""" + from core.ops.ops_trace_manager import _PROVIDER_TYPE_TO_MODEL + from models.tools import MCPToolProvider + + assert _PROVIDER_TYPE_TO_MODEL["mcp"] is MCPToolProvider + + +# --------------------------------------------------------------------------- +# _lookup_llm_credential_info +# --------------------------------------------------------------------------- + + +class TestLookupLlmCredentialInfo: + """Tests for 
_lookup_llm_credential_info(tenant_id, provider, model, model_type).""" + + def _provider_record(self, credential_id: str | None = None) -> MagicMock: + record = MagicMock() + record.credential_id = credential_id + return record + + def _model_record(self, credential_id: str | None = None) -> MagicMock: + record = MagicMock() + record.credential_id = credential_id + return record + + def test_model_level_credential_found(self): + """Returns model-level credential_id and name when ProviderModel has a credential.""" + from core.ops.ops_trace_manager import _lookup_llm_credential_info + + provider_record = self._provider_record(credential_id=None) + model_record = self._model_record(credential_id="model-cred-id") + + # scalar calls: (1) Provider, (2) ProviderModel, (3) ProviderModelCredential.credential_name + mock_db, cm, _session = _make_db_and_session_patches( + scalar_side_effect=[provider_record, model_record, "ModelCredName"] + ) + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", return_value=cm), + ): + cred_id, cred_name = _lookup_llm_credential_info("tenant-1", "openai", "gpt-4") + + assert cred_id == "model-cred-id" + assert cred_name == "ModelCredName" + + def test_provider_level_fallback_when_no_model_credential(self): + """Falls back to provider-level credential when ProviderModel has no credential_id.""" + from core.ops.ops_trace_manager import _lookup_llm_credential_info + + provider_record = self._provider_record(credential_id="prov-cred-id") + model_record = self._model_record(credential_id=None) + + # scalar calls: (1) Provider, (2) ProviderModel (no cred), (3) ProviderCredential.credential_name + mock_db, cm, _session = _make_db_and_session_patches( + scalar_side_effect=[provider_record, model_record, "ProvCredName"] + ) + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", return_value=cm), + ): + cred_id, cred_name = 
_lookup_llm_credential_info("tenant-1", "openai", "gpt-4") + + assert cred_id == "prov-cred-id" + assert cred_name == "ProvCredName" + + def test_provider_level_fallback_when_no_model_record(self): + """Falls back to provider-level credential when no ProviderModel row exists.""" + from core.ops.ops_trace_manager import _lookup_llm_credential_info + + provider_record = self._provider_record(credential_id="prov-cred-id") + + # scalar calls: (1) Provider, (2) ProviderModel → None, (3) ProviderCredential.credential_name + mock_db, cm, _session = _make_db_and_session_patches(scalar_side_effect=[provider_record, None, "ProvCredName"]) + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", return_value=cm), + ): + cred_id, cred_name = _lookup_llm_credential_info("tenant-1", "openai", "gpt-4") + + assert cred_id == "prov-cred-id" + assert cred_name == "ProvCredName" + + def test_no_model_arg_uses_provider_level_only(self): + """When model is None, skips ProviderModel query and uses provider credential.""" + from core.ops.ops_trace_manager import _lookup_llm_credential_info + + provider_record = self._provider_record(credential_id="prov-cred-id") + + # scalar calls: (1) Provider, (2) ProviderCredential.credential_name — no ProviderModel + mock_db, cm, session = _make_db_and_session_patches(scalar_side_effect=[provider_record, "ProvCredName"]) + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", return_value=cm), + ): + cred_id, cred_name = _lookup_llm_credential_info("tenant-1", "openai", None) + + assert cred_id == "prov-cred-id" + assert cred_name == "ProvCredName" + assert session.scalar.call_count == 2 + + def test_provider_not_found_returns_none_and_empty(self): + """Returns (None, '') when Provider record does not exist.""" + from core.ops.ops_trace_manager import _lookup_llm_credential_info + + mock_db, cm, _session = 
_make_db_and_session_patches(scalar_return_value=None) + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", return_value=cm), + ): + cred_id, cred_name = _lookup_llm_credential_info("tenant-1", "openai", "gpt-4") + + assert cred_id is None + assert cred_name == "" + + def test_none_tenant_id_returns_none_and_empty_without_db(self): + """Returns (None, '') immediately when tenant_id is None — no DB access.""" + from core.ops.ops_trace_manager import _lookup_llm_credential_info + + mock_db = MagicMock() + mock_session_cls = MagicMock() + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", mock_session_cls), + ): + cred_id, cred_name = _lookup_llm_credential_info(None, "openai", "gpt-4") + + mock_session_cls.assert_not_called() + assert cred_id is None + assert cred_name == "" + + def test_none_provider_returns_none_and_empty_without_db(self): + """Returns (None, '') immediately when provider is None — no DB access.""" + from core.ops.ops_trace_manager import _lookup_llm_credential_info + + mock_db = MagicMock() + mock_session_cls = MagicMock() + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", mock_session_cls), + ): + cred_id, cred_name = _lookup_llm_credential_info("tenant-1", None, "gpt-4") + + mock_session_cls.assert_not_called() + assert cred_id is None + assert cred_name == "" + + def test_db_error_on_outer_query_returns_none_and_empty(self): + """Returns (None, '') and logs a warning when the outer DB query raises.""" + from core.ops.ops_trace_manager import _lookup_llm_credential_info + + mock_db, cm, session = _make_db_and_session_patches() + session.scalar.side_effect = Exception("DB connection failed") + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", return_value=cm), + ): + cred_id, cred_name = _lookup_llm_credential_info("tenant-1", 
"openai", "gpt-4") + + assert cred_id is None + assert cred_name == "" + + def test_credential_name_lookup_failure_returns_id_with_empty_name(self): + """When credential name sub-query fails, returns cred_id but '' for name.""" + from core.ops.ops_trace_manager import _lookup_llm_credential_info + + provider_record = self._provider_record(credential_id="prov-cred-id") + + # Provider found, no model record, then name lookup raises + mock_db, cm, _session = _make_db_and_session_patches( + scalar_side_effect=[provider_record, None, Exception("deleted")] + ) + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", return_value=cm), + ): + cred_id, cred_name = _lookup_llm_credential_info("tenant-1", "openai", "gpt-4") + + assert cred_id == "prov-cred-id" + assert cred_name == "" + + def test_no_credential_on_provider_or_model_returns_none_id(self): + """Returns (None, '') when neither provider nor model has a credential_id.""" + from core.ops.ops_trace_manager import _lookup_llm_credential_info + + provider_record = self._provider_record(credential_id=None) + model_record = self._model_record(credential_id=None) + + mock_db, cm, _session = _make_db_and_session_patches(scalar_side_effect=[provider_record, model_record]) + + with ( + patch("core.ops.ops_trace_manager.db", mock_db), + patch("core.ops.ops_trace_manager.Session", return_value=cm), + ): + cred_id, cred_name = _lookup_llm_credential_info("tenant-1", "openai", "gpt-4") + + assert cred_id is None + assert cred_name == "" + + +# --------------------------------------------------------------------------- +# TraceTask._get_user_id_from_metadata +# --------------------------------------------------------------------------- + + +class TestGetUserIdFromMetadata: + """Tests for TraceTask._get_user_id_from_metadata(metadata). + + Pure dict logic — no DB access required. 
+ """ + + @pytest.fixture + def get_user_id(self): + """Return the classmethod under test.""" + from core.ops.ops_trace_manager import TraceTask + + return TraceTask._get_user_id_from_metadata + + def test_from_end_user_id_has_highest_priority(self, get_user_id): + """from_end_user_id takes precedence over all other keys.""" + metadata = { + "from_end_user_id": "eu-abc", + "from_account_id": "acc-xyz", + "user_id": "u-123", + } + assert get_user_id(metadata) == "end_user:eu-abc" + + def test_from_account_id_used_when_no_end_user(self, get_user_id): + """from_account_id is used when from_end_user_id is absent.""" + metadata = { + "from_account_id": "acc-xyz", + "user_id": "u-123", + } + assert get_user_id(metadata) == "account:acc-xyz" + + def test_user_id_used_when_no_end_user_or_account(self, get_user_id): + """user_id is used when both higher-priority keys are absent.""" + metadata = {"user_id": "u-123"} + assert get_user_id(metadata) == "user:u-123" + + def test_returns_anonymous_when_all_keys_absent(self, get_user_id): + """Returns 'anonymous' when metadata has none of the expected keys.""" + assert get_user_id({}) == "anonymous" + + def test_empty_string_end_user_id_is_skipped(self, get_user_id): + """Empty string for from_end_user_id is falsy and falls through to next key.""" + metadata = { + "from_end_user_id": "", + "from_account_id": "acc-xyz", + } + assert get_user_id(metadata) == "account:acc-xyz" + + def test_empty_string_account_id_is_skipped(self, get_user_id): + """Empty string for from_account_id is falsy and falls through to user_id.""" + metadata = { + "from_end_user_id": "", + "from_account_id": "", + "user_id": "u-123", + } + assert get_user_id(metadata) == "user:u-123" + + def test_empty_string_user_id_falls_through_to_anonymous(self, get_user_id): + """Empty string for user_id is falsy, so 'anonymous' is returned.""" + metadata = { + "from_end_user_id": "", + "from_account_id": "", + "user_id": "", + } + assert get_user_id(metadata) == 
"anonymous" + + def test_only_from_end_user_id_present(self, get_user_id): + """Minimal case: only from_end_user_id present.""" + assert get_user_id({"from_end_user_id": "eu-only"}) == "end_user:eu-only" + + def test_irrelevant_keys_do_not_interfere(self, get_user_id): + """Extra metadata keys have no effect on the result.""" + metadata = {"invoke_from": "web", "app_id": "a1"} + assert get_user_id(metadata) == "anonymous" diff --git a/api/tests/unit_tests/core/ops/test_opik_trace.py b/api/tests/unit_tests/core/ops/test_opik_trace.py index 7660967183..ad9d0846be 100644 --- a/api/tests/unit_tests/core/ops/test_opik_trace.py +++ b/api/tests/unit_tests/core/ops/test_opik_trace.py @@ -130,7 +130,7 @@ class TestWorkflowTraceWithoutMessageId: def _run(self, trace_info: WorkflowTraceInfo, node_executions: list | None = None): instance = _make_opik_trace_instance() fake_repo = MagicMock() - fake_repo.get_by_workflow_run.return_value = node_executions or [] + fake_repo.get_by_workflow_execution.return_value = node_executions or [] with ( patch("core.ops.opik_trace.opik_trace.db") as mock_db, @@ -262,7 +262,7 @@ class TestWorkflowTraceWithMessageId: def _run(self, trace_info: WorkflowTraceInfo, node_executions: list | None = None): instance = _make_opik_trace_instance() fake_repo = MagicMock() - fake_repo.get_by_workflow_run.return_value = node_executions or [] + fake_repo.get_by_workflow_execution.return_value = node_executions or [] with ( patch("core.ops.opik_trace.opik_trace.db") as mock_db, diff --git a/api/tests/unit_tests/core/ops/test_ops_trace_manager.py b/api/tests/unit_tests/core/ops/test_ops_trace_manager.py index 2d325ccb0e..e47df0121e 100644 --- a/api/tests/unit_tests/core/ops/test_ops_trace_manager.py +++ b/api/tests/unit_tests/core/ops/test_ops_trace_manager.py @@ -86,6 +86,7 @@ def make_message_data(**overrides): created_at = datetime(2025, 2, 20, 12, 0, 0) base = { "id": "msg-id", + "app_id": "app-id", "conversation_id": "conv-id", "created_at": created_at, 
"updated_at": created_at + timedelta(seconds=3), @@ -156,17 +157,19 @@ def make_workflow_run(): ) -def configure_db_query(session, *, message_file=None, workflow_app_log=None): - def _side_effect(model): - query = MagicMock() - query.filter_by.return_value.first.return_value = None - if message_file and model.__name__ == "MessageFile": - query.filter_by.return_value.first.return_value = message_file - if workflow_app_log and model.__name__ == "WorkflowAppLog": - query.filter_by.return_value.first.return_value = workflow_app_log - return query +def configure_db_scalar(session, *, message_file=None, workflow_app_log=None): + """Configure session.scalar to return appropriate values for MessageFile/WorkflowAppLog lookups.""" + original_scalar = session.scalar - session.query.side_effect = _side_effect + def _side_effect(stmt): + stmt_str = str(stmt) + if "message_file" in stmt_str.lower(): + return message_file + if "workflow_app_log" in stmt_str.lower(): + return workflow_app_log + return original_scalar(stmt) + + session.scalar.side_effect = _side_effect class DummySessionContext: @@ -182,6 +185,9 @@ class DummySessionContext: def __exit__(self, exc_type, exc_val, exc_tb): return False + def execute(self, *args, **kwargs): + return self + def scalar(self, *args, **kwargs): if self._index >= len(self._values): return None @@ -189,6 +195,12 @@ class DummySessionContext: self._index += 1 return value + def scalars(self, *args, **kwargs): + return self + + def all(self): + return [] + @pytest.fixture(autouse=True) def patch_provider_map(monkeypatch): @@ -253,7 +265,7 @@ def workflow_repo_fixture(monkeypatch): def trace_task_message(monkeypatch, mock_db): message_data = make_message_data() monkeypatch.setattr("core.ops.ops_trace_manager.get_message_data", lambda msg_id: message_data) - configure_db_query(mock_db, message_file=FakeMessageFile(), workflow_app_log=SimpleNamespace(id="log-id")) + configure_db_scalar(mock_db, message_file=FakeMessageFile(), 
workflow_app_log=SimpleNamespace(id="log-id")) return message_data @@ -297,56 +309,53 @@ def test_obfuscated_decrypt_token(encryption_mocks): def test_get_decrypted_tracing_config_returns_config(encryption_mocks, mock_db): trace_config_data = SimpleNamespace(tracing_config={"secret_value": "enc", "other_value": "info"}) - mock_db.query.return_value.where.return_value.first.return_value = trace_config_data app = SimpleNamespace(id="app-id", tenant_id="tenant") - mock_db.scalar.return_value = app + mock_db.scalar.side_effect = [trace_config_data, app] decrypted = OpsTraceManager.get_decrypted_tracing_config("app-id", "dummy") assert decrypted["other_value"] == "info" def test_get_decrypted_tracing_config_missing_trace_config(mock_db): - mock_db.query.return_value.where.return_value.first.return_value = None + mock_db.scalar.return_value = None assert OpsTraceManager.get_decrypted_tracing_config("app-id", "dummy") is None def test_get_decrypted_tracing_config_raises_for_missing_app(mock_db): trace_config_data = SimpleNamespace(tracing_config={"secret_value": "enc"}) - mock_db.query.return_value.where.return_value.first.return_value = trace_config_data - mock_db.scalar.return_value = None + mock_db.scalar.side_effect = [trace_config_data, None] with pytest.raises(ValueError, match="App not found"): OpsTraceManager.get_decrypted_tracing_config("app-id", "dummy") def test_get_decrypted_tracing_config_raises_for_none_config(mock_db): trace_config_data = SimpleNamespace(tracing_config=None) - mock_db.query.return_value.where.return_value.first.return_value = trace_config_data - mock_db.scalar.return_value = SimpleNamespace(tenant_id="tenant") + mock_db.scalar.side_effect = [trace_config_data, SimpleNamespace(tenant_id="tenant")] with pytest.raises(ValueError, match="Tracing config cannot be None"): OpsTraceManager.get_decrypted_tracing_config("app-id", "dummy") def test_get_ops_trace_instance_handles_none_app(mock_db): - 
mock_db.query.return_value.where.return_value.first.return_value = None + mock_db.get.return_value = None assert OpsTraceManager.get_ops_trace_instance("app-id") is None def test_get_ops_trace_instance_returns_none_when_disabled(mock_db, monkeypatch): app = SimpleNamespace(id="app-id", tracing=json.dumps({"enabled": False})) - mock_db.query.return_value.where.return_value.first.return_value = app + mock_db.get.return_value = app assert OpsTraceManager.get_ops_trace_instance("app-id") is None def test_get_ops_trace_instance_invalid_provider(mock_db, monkeypatch): app = SimpleNamespace(id="app-id", tracing=json.dumps({"enabled": True, "tracing_provider": "missing"})) - mock_db.query.return_value.where.return_value.first.return_value = app + mock_db.get.return_value = app monkeypatch.setattr("core.ops.ops_trace_manager.provider_config_map", FakeProviderMap({})) assert OpsTraceManager.get_ops_trace_instance("app-id") is None def test_get_ops_trace_instance_success(monkeypatch, mock_db): app = SimpleNamespace(id="app-id", tracing=json.dumps({"enabled": True, "tracing_provider": "dummy"})) - mock_db.query.return_value.where.return_value.first.return_value = app + mock_db.get.return_value = app monkeypatch.setattr( "core.ops.ops_trace_manager.OpsTraceManager.get_decrypted_tracing_config", classmethod(lambda cls, aid, provider: {"secret_value": "decrypted", "other_value": "info"}), @@ -380,7 +389,7 @@ def test_get_app_config_through_message_id_app_model_config(mock_db): def test_update_app_tracing_config_invalid_provider(mock_db, monkeypatch): - mock_db.query.return_value.where.return_value.first.return_value = None + mock_db.get.return_value = None with pytest.raises(ValueError, match="Invalid tracing provider"): OpsTraceManager.update_app_tracing_config("app", True, "bad") with pytest.raises(ValueError, match="App not found"): @@ -389,26 +398,26 @@ def test_update_app_tracing_config_invalid_provider(mock_db, monkeypatch): def 
test_update_app_tracing_config_success(mock_db): app = SimpleNamespace(id="app-id", tracing="{}") - mock_db.query.return_value.where.return_value.first.return_value = app + mock_db.get.return_value = app OpsTraceManager.update_app_tracing_config("app-id", True, "dummy") assert app.tracing is not None mock_db.commit.assert_called_once() def test_get_app_tracing_config_errors_when_missing(mock_db): - mock_db.query.return_value.where.return_value.first.return_value = None + mock_db.get.return_value = None with pytest.raises(ValueError, match="App not found"): OpsTraceManager.get_app_tracing_config("app") def test_get_app_tracing_config_returns_defaults(mock_db): - mock_db.query.return_value.where.return_value.first.return_value = SimpleNamespace(tracing=None) + mock_db.get.return_value = SimpleNamespace(tracing=None) assert OpsTraceManager.get_app_tracing_config("app-id") == {"enabled": False, "tracing_provider": None} def test_get_app_tracing_config_returns_payload(mock_db): payload = {"enabled": True, "tracing_provider": "dummy"} - mock_db.query.return_value.where.return_value.first.return_value = SimpleNamespace(tracing=json.dumps(payload)) + mock_db.get.return_value = SimpleNamespace(tracing=json.dumps(payload)) assert OpsTraceManager.get_app_tracing_config("app-id") == payload @@ -454,7 +463,7 @@ def test_trace_task_message_trace(trace_task_message, mock_db): def test_trace_task_workflow_trace(workflow_repo_fixture, mock_db): DummySessionContext.scalar_values = ["wf-app-log", "message-ref"] - execution = SimpleNamespace(id_="run-id") + execution = SimpleNamespace(id_="run-id", total_tokens=0) task = TraceTask( trace_type=TraceTaskName.WORKFLOW_TRACE, workflow_execution=execution, conversation_id="conv", user_id="user" ) @@ -491,7 +500,7 @@ def test_trace_task_dataset_retrieval_trace(trace_task_message): def test_trace_task_tool_trace(monkeypatch, mock_db): custom_message = make_message_data(agent_thoughts=[make_agent_thought("tool-a", datetime(2025, 2, 20, 12, 1, 
0))]) monkeypatch.setattr("core.ops.ops_trace_manager.get_message_data", lambda _: custom_message) - configure_db_query(mock_db, message_file=FakeMessageFile()) + configure_db_scalar(mock_db, message_file=FakeMessageFile()) task = TraceTask(trace_type=TraceTaskName.TOOL_TRACE, message_id="msg-id") timer = {"start": 1, "end": 5} result = task.tool_trace("msg-id", timer, tool_name="tool-a", tool_inputs={"foo": 1}, tool_outputs="result") diff --git a/api/tests/unit_tests/core/ops/test_trace_queue_manager.py b/api/tests/unit_tests/core/ops/test_trace_queue_manager.py new file mode 100644 index 0000000000..a4903054e0 --- /dev/null +++ b/api/tests/unit_tests/core/ops/test_trace_queue_manager.py @@ -0,0 +1,194 @@ +"""Unit tests for TraceQueueManager telemetry guard. + +Verifies that TraceQueueManager.add_trace_task() only enqueues tasks when at +least one consumer is active: +- Enterprise telemetry is enabled (_enterprise_telemetry_enabled=True), OR +- A third-party trace instance (Langfuse, etc.) is configured + +When neither is active, tasks are silently dropped to avoid unnecessary work. + +When BOTH are false, tasks are silently dropped (correct behavior). 
+""" + +import queue +import sys +import types +from unittest.mock import MagicMock, patch + +import pytest + + +@pytest.fixture +def trace_queue_manager_and_task(monkeypatch): + """Fixture to provide TraceQueueManager and TraceTask with delayed imports.""" + module_name = "core.ops.ops_trace_manager" + if module_name not in sys.modules: + ops_stub = types.ModuleType(module_name) + + class StubTraceTask: + def __init__(self, trace_type): + self.trace_type = trace_type + self.app_id = None + + class StubTraceQueueManager: + def __init__(self, app_id=None): + self.app_id = app_id + from core.telemetry.gateway import is_enterprise_telemetry_enabled + + self._enterprise_telemetry_enabled = is_enterprise_telemetry_enabled() + self.trace_instance = StubOpsTraceManager.get_ops_trace_instance(app_id) + + def add_trace_task(self, trace_task): + if self._enterprise_telemetry_enabled or self.trace_instance: + trace_task.app_id = self.app_id + from core.ops.ops_trace_manager import trace_manager_queue + + trace_manager_queue.put(trace_task) + + class StubOpsTraceManager: + @staticmethod + def get_ops_trace_instance(app_id): + return None + + ops_stub.TraceQueueManager = StubTraceQueueManager + ops_stub.TraceTask = StubTraceTask + ops_stub.OpsTraceManager = StubOpsTraceManager + ops_stub.trace_manager_queue = MagicMock(spec=queue.Queue) + monkeypatch.setitem(sys.modules, module_name, ops_stub) + + from core.ops.entities.trace_entity import TraceTaskName + + ops_module = __import__(module_name, fromlist=["TraceQueueManager", "TraceTask"]) + TraceQueueManager = ops_module.TraceQueueManager + TraceTask = ops_module.TraceTask + + return TraceQueueManager, TraceTask, TraceTaskName + + +class TestTraceQueueManagerTelemetryGuard: + """Test TraceQueueManager's telemetry guard in add_trace_task().""" + + def test_task_not_enqueued_when_telemetry_disabled_and_no_trace_instance(self, trace_queue_manager_and_task): + """Verify task is NOT enqueued when telemetry disabled and no trace 
instance. + + This is the core guard: when _enterprise_telemetry_enabled=False AND + trace_instance=None, the task should be silently dropped. + """ + TraceQueueManager, TraceTask, TraceTaskName = trace_queue_manager_and_task + + mock_queue = MagicMock(spec=queue.Queue) + + trace_task = TraceTask(trace_type=TraceTaskName.WORKFLOW_TRACE) + + with ( + patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=False), + patch("core.ops.ops_trace_manager.OpsTraceManager.get_ops_trace_instance", return_value=None), + patch("core.ops.ops_trace_manager.trace_manager_queue", mock_queue), + ): + manager = TraceQueueManager(app_id="test-app-id") + manager.add_trace_task(trace_task) + + mock_queue.put.assert_not_called() + + def test_task_enqueued_when_telemetry_enabled(self, trace_queue_manager_and_task): + """Verify task IS enqueued when enterprise telemetry is enabled. + + When _enterprise_telemetry_enabled=True, the task should be enqueued + regardless of trace_instance state. + """ + TraceQueueManager, TraceTask, TraceTaskName = trace_queue_manager_and_task + + mock_queue = MagicMock(spec=queue.Queue) + + trace_task = TraceTask(trace_type=TraceTaskName.WORKFLOW_TRACE) + + with ( + patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True), + patch("core.ops.ops_trace_manager.OpsTraceManager.get_ops_trace_instance", return_value=None), + patch("core.ops.ops_trace_manager.trace_manager_queue", mock_queue), + ): + manager = TraceQueueManager(app_id="test-app-id") + manager.add_trace_task(trace_task) + + mock_queue.put.assert_called_once() + called_task = mock_queue.put.call_args[0][0] + assert called_task.app_id == "test-app-id" + + def test_task_enqueued_when_trace_instance_configured(self, trace_queue_manager_and_task): + """Verify task IS enqueued when third-party trace instance is configured. + + When trace_instance is not None (e.g., Langfuse configured), the task + should be enqueued even if enterprise telemetry is disabled. 
+ """ + TraceQueueManager, TraceTask, TraceTaskName = trace_queue_manager_and_task + + mock_queue = MagicMock(spec=queue.Queue) + + mock_trace_instance = MagicMock() + + trace_task = TraceTask(trace_type=TraceTaskName.WORKFLOW_TRACE) + + with ( + patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=False), + patch( + "core.ops.ops_trace_manager.OpsTraceManager.get_ops_trace_instance", return_value=mock_trace_instance + ), + patch("core.ops.ops_trace_manager.trace_manager_queue", mock_queue), + ): + manager = TraceQueueManager(app_id="test-app-id") + manager.add_trace_task(trace_task) + + mock_queue.put.assert_called_once() + called_task = mock_queue.put.call_args[0][0] + assert called_task.app_id == "test-app-id" + + def test_task_enqueued_when_both_telemetry_and_trace_instance_enabled(self, trace_queue_manager_and_task): + """Verify task IS enqueued when both telemetry and trace instance are enabled. + + When both _enterprise_telemetry_enabled=True AND trace_instance is set, + the task should definitely be enqueued. + """ + TraceQueueManager, TraceTask, TraceTaskName = trace_queue_manager_and_task + + mock_queue = MagicMock(spec=queue.Queue) + + mock_trace_instance = MagicMock() + + trace_task = TraceTask(trace_type=TraceTaskName.WORKFLOW_TRACE) + + with ( + patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True), + patch( + "core.ops.ops_trace_manager.OpsTraceManager.get_ops_trace_instance", return_value=mock_trace_instance + ), + patch("core.ops.ops_trace_manager.trace_manager_queue", mock_queue), + ): + manager = TraceQueueManager(app_id="test-app-id") + manager.add_trace_task(trace_task) + + mock_queue.put.assert_called_once() + called_task = mock_queue.put.call_args[0][0] + assert called_task.app_id == "test-app-id" + + def test_app_id_set_before_enqueue(self, trace_queue_manager_and_task): + """Verify app_id is set on the task before enqueuing. 
+ + The guard logic sets trace_task.app_id = self.app_id before calling + trace_manager_queue.put(trace_task). This test verifies that behavior. + """ + TraceQueueManager, TraceTask, TraceTaskName = trace_queue_manager_and_task + + mock_queue = MagicMock(spec=queue.Queue) + + trace_task = TraceTask(trace_type=TraceTaskName.WORKFLOW_TRACE) + + with ( + patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True), + patch("core.ops.ops_trace_manager.OpsTraceManager.get_ops_trace_instance", return_value=None), + patch("core.ops.ops_trace_manager.trace_manager_queue", mock_queue), + ): + manager = TraceQueueManager(app_id="expected-app-id") + manager.add_trace_task(trace_task) + + called_task = mock_queue.put.call_args[0][0] + assert called_task.app_id == "expected-app-id" diff --git a/api/tests/unit_tests/core/ops/weave_trace/test_weave_trace.py b/api/tests/unit_tests/core/ops/weave_trace/test_weave_trace.py index 8057bbbad5..5014f40afc 100644 --- a/api/tests/unit_tests/core/ops/weave_trace/test_weave_trace.py +++ b/api/tests/unit_tests/core/ops/weave_trace/test_weave_trace.py @@ -7,6 +7,7 @@ from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey from weave.trace_server.trace_server_interface import TraceStatus from core.ops.entities.config_entity import WeaveConfig @@ -22,7 +23,6 @@ from core.ops.entities.trace_entity import ( ) from core.ops.weave_trace.entities.weave_trace_entity import WeaveTraceModel from core.ops.weave_trace.weave_trace import WeaveDataTrace -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionMetadataKey # ── Helpers ────────────────────────────────────────────────────────────────── @@ -589,7 +589,7 @@ class TestWorkflowTrace: nodes = [] repo = MagicMock() - repo.get_by_workflow_run.return_value = nodes + repo.get_by_workflow_execution.return_value = nodes mock_factory = MagicMock() 
mock_factory.create_workflow_node_execution_repository.return_value = repo @@ -802,8 +802,8 @@ class TestMessageTrace: def test_basic_message_trace(self, trace_instance, monkeypatch): """message_trace creates message run and llm child run.""" monkeypatch.setattr( - "core.ops.weave_trace.weave_trace.db.session.query", - lambda model: MagicMock(where=lambda: MagicMock(first=lambda: None)), + "core.ops.weave_trace.weave_trace.db.session.get", + lambda model, pk: None, ) trace_instance.start_call = MagicMock() @@ -823,7 +823,7 @@ class TestMessageTrace: trace_instance.file_base_url = "http://files.test" mock_db = MagicMock() - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.get.return_value = None monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", mock_db) trace_instance.start_call = MagicMock() @@ -845,7 +845,7 @@ class TestMessageTrace: end_user.session_id = "session-xyz" mock_db = MagicMock() - mock_db.session.query.return_value.where.return_value.first.return_value = end_user + mock_db.session.get.return_value = end_user monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", mock_db) trace_instance.start_call = MagicMock() @@ -865,7 +865,7 @@ class TestMessageTrace: def test_message_trace_no_end_user(self, trace_instance, monkeypatch): """message_trace handles when from_end_user_id is None.""" mock_db = MagicMock() - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.get.return_value = None monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", mock_db) trace_instance.start_call = MagicMock() @@ -883,7 +883,7 @@ class TestMessageTrace: def test_message_trace_trace_id_fallback_to_message_id(self, trace_instance, monkeypatch): """trace_id falls back to message_id when trace_id is None.""" mock_db = MagicMock() - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.get.return_value = None 
monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", mock_db) trace_instance.start_call = MagicMock() @@ -898,7 +898,7 @@ class TestMessageTrace: def test_message_trace_file_list_none(self, trace_instance, monkeypatch): """message_trace handles file_list=None gracefully.""" mock_db = MagicMock() - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.get.return_value = None monkeypatch.setattr("core.ops.weave_trace.weave_trace.db", mock_db) trace_instance.start_call = MagicMock() diff --git a/api/tests/unit_tests/core/plugin/impl/test_model_runtime_factory.py b/api/tests/unit_tests/core/plugin/impl/test_model_runtime_factory.py new file mode 100644 index 0000000000..7491e79f30 --- /dev/null +++ b/api/tests/unit_tests/core/plugin/impl/test_model_runtime_factory.py @@ -0,0 +1,36 @@ +from unittest.mock import Mock, patch + +from core.plugin.impl.model_runtime_factory import create_plugin_model_assembly + + +def test_plugin_model_assembly_reuses_single_runtime_across_views(): + runtime = Mock(name="runtime") + provider_factory = Mock(name="provider_factory") + provider_manager = Mock(name="provider_manager") + model_manager = Mock(name="model_manager") + + with ( + patch( + "core.plugin.impl.model_runtime_factory.create_plugin_model_runtime", + return_value=runtime, + ) as mock_runtime_factory, + patch( + "core.plugin.impl.model_runtime_factory.ModelProviderFactory", + return_value=provider_factory, + ) as mock_provider_factory_cls, + patch("core.provider_manager.ProviderManager", return_value=provider_manager) as mock_provider_manager_cls, + patch("core.model_manager.ModelManager", return_value=model_manager) as mock_model_manager_cls, + ): + assembly = create_plugin_model_assembly(tenant_id="tenant-1", user_id="user-1") + + assert assembly.model_provider_factory is provider_factory + assert assembly.provider_manager is provider_manager + assert assembly.model_manager is model_manager + assert 
assembly.model_provider_factory is provider_factory + assert assembly.provider_manager is provider_manager + assert assembly.model_manager is model_manager + + mock_runtime_factory.assert_called_once_with(tenant_id="tenant-1", user_id="user-1") + mock_provider_factory_cls.assert_called_once_with(model_runtime=runtime) + mock_provider_manager_cls.assert_called_once_with(model_runtime=runtime) + mock_model_manager_cls.assert_called_once_with(provider_manager=provider_manager) diff --git a/api/tests/unit_tests/core/plugin/test_backwards_invocation_app.py b/api/tests/unit_tests/core/plugin/test_backwards_invocation_app.py index c2778f082b..3feb4159ad 100644 --- a/api/tests/unit_tests/core/plugin/test_backwards_invocation_app.py +++ b/api/tests/unit_tests/core/plugin/test_backwards_invocation_app.py @@ -332,27 +332,21 @@ class TestPluginAppBackwardsInvocation: PluginAppBackwardsInvocation._get_user("uid") def test_get_app_returns_app(self, mocker): - query_chain = MagicMock() - query_chain.where.return_value = query_chain app_obj = MagicMock(id="app") - query_chain.first.return_value = app_obj - db = SimpleNamespace(session=MagicMock(query=MagicMock(return_value=query_chain))) + db = SimpleNamespace(session=MagicMock(scalar=MagicMock(return_value=app_obj))) mocker.patch("core.plugin.backwards_invocation.app.db", db) assert PluginAppBackwardsInvocation._get_app("app", "tenant") is app_obj def test_get_app_raises_when_missing(self, mocker): - query_chain = MagicMock() - query_chain.where.return_value = query_chain - query_chain.first.return_value = None - db = SimpleNamespace(session=MagicMock(query=MagicMock(return_value=query_chain))) + db = SimpleNamespace(session=MagicMock(scalar=MagicMock(return_value=None))) mocker.patch("core.plugin.backwards_invocation.app.db", db) with pytest.raises(ValueError, match="app not found"): PluginAppBackwardsInvocation._get_app("app", "tenant") def test_get_app_raises_when_query_fails(self, mocker): - db = 
SimpleNamespace(session=MagicMock(query=MagicMock(side_effect=RuntimeError("db down")))) + db = SimpleNamespace(session=MagicMock(scalar=MagicMock(side_effect=RuntimeError("db down")))) mocker.patch("core.plugin.backwards_invocation.app.db", db) with pytest.raises(ValueError, match="app not found"): diff --git a/api/tests/unit_tests/core/plugin/test_backwards_invocation_model.py b/api/tests/unit_tests/core/plugin/test_backwards_invocation_model.py new file mode 100644 index 0000000000..543b278715 --- /dev/null +++ b/api/tests/unit_tests/core/plugin/test_backwards_invocation_model.py @@ -0,0 +1,62 @@ +from types import SimpleNamespace +from unittest.mock import patch + +from graphon.model_runtime.entities.message_entities import UserPromptMessage + +from core.plugin.backwards_invocation.model import PluginModelBackwardsInvocation +from core.plugin.entities.request import RequestInvokeSummary + + +def test_system_model_helpers_forward_user_id(): + with ( + patch( + "core.plugin.backwards_invocation.model.ModelInvocationUtils.get_max_llm_context_tokens", + return_value=4096, + ) as mock_max_tokens, + patch( + "core.plugin.backwards_invocation.model.ModelInvocationUtils.calculate_tokens", + return_value=7, + ) as mock_prompt_tokens, + ): + assert PluginModelBackwardsInvocation.get_system_model_max_tokens("tenant-1", user_id="user-1") == 4096 + assert ( + PluginModelBackwardsInvocation.get_prompt_tokens( + "tenant-1", + [UserPromptMessage(content="hello")], + user_id="user-1", + ) + == 7 + ) + + mock_max_tokens.assert_called_once_with(tenant_id="tenant-1", user_id="user-1") + mock_prompt_tokens.assert_called_once_with( + tenant_id="tenant-1", + prompt_messages=[UserPromptMessage(content="hello")], + user_id="user-1", + ) + + +def test_invoke_summary_uses_same_user_scope_for_token_helpers(): + tenant = SimpleNamespace(id="tenant-1") + payload = RequestInvokeSummary(text="short", instruction="keep it concise") + + with ( + patch.object( + PluginModelBackwardsInvocation, + 
"get_system_model_max_tokens", + return_value=100, + ) as mock_max_tokens, + patch.object( + PluginModelBackwardsInvocation, + "get_prompt_tokens", + return_value=10, + ) as mock_prompt_tokens, + ): + assert PluginModelBackwardsInvocation.invoke_summary("user-1", tenant, payload) == "short" + + mock_max_tokens.assert_called_once_with(tenant_id="tenant-1", user_id="user-1") + mock_prompt_tokens.assert_called_once_with( + tenant_id="tenant-1", + prompt_messages=[UserPromptMessage(content="short")], + user_id="user-1", + ) diff --git a/api/tests/unit_tests/core/plugin/test_model_runtime_adapter.py b/api/tests/unit_tests/core/plugin/test_model_runtime_adapter.py new file mode 100644 index 0000000000..f8d0e127b1 --- /dev/null +++ b/api/tests/unit_tests/core/plugin/test_model_runtime_adapter.py @@ -0,0 +1,506 @@ +"""Unit tests for the plugin-backed model runtime adapter.""" + +import datetime +import uuid +from types import SimpleNamespace +from unittest.mock import Mock, sentinel + +import pytest +from graphon.model_runtime.entities.common_entities import I18nObject +from graphon.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType +from graphon.model_runtime.entities.provider_entities import ConfigurateMethod, ProviderEntity + +from core.plugin.entities.plugin_daemon import PluginModelProviderEntity +from core.plugin.impl import model_runtime as model_runtime_module +from core.plugin.impl.model import PluginModelClient +from core.plugin.impl.model_runtime import TENANT_SCOPE_SCHEMA_CACHE_USER_ID, PluginModelRuntime +from core.plugin.impl.model_runtime_factory import create_plugin_model_runtime + + +def _build_model_schema() -> AIModelEntity: + return AIModelEntity( + model="gpt-4o-mini", + label=I18nObject(en_US="GPT-4o mini"), + model_type=ModelType.LLM, + fetch_from=FetchFrom.PREDEFINED_MODEL, + model_properties={}, + ) + + +class TestPluginModelRuntime: + """Validate the adapter keeps plugin-specific routing out of the runtime port.""" + 
+ def test_fetch_model_providers_returns_runtime_entities(self) -> None: + client = Mock(spec=PluginModelClient) + client.fetch_model_providers.return_value = [ + PluginModelProviderEntity( + id=uuid.uuid4().hex, + created_at=datetime.datetime.now(), + updated_at=datetime.datetime.now(), + provider="openai", + tenant_id="tenant", + plugin_unique_identifier="langgenius/openai/openai", + plugin_id="langgenius/openai", + declaration=ProviderEntity( + provider="openai", + label=I18nObject(en_US="OpenAI"), + supported_model_types=[], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + ), + ) + ] + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) + + providers = runtime.fetch_model_providers() + + assert len(providers) == 1 + assert providers[0].provider == "langgenius/openai/openai" + assert providers[0].provider_name == "openai" + assert providers[0].label.en_US == "OpenAI" + client.fetch_model_providers.assert_called_once_with("tenant") + + def test_fetch_model_providers_only_exposes_short_name_for_canonical_provider(self) -> None: + client = Mock(spec=PluginModelClient) + client.fetch_model_providers.return_value = [ + PluginModelProviderEntity( + id=uuid.uuid4().hex, + created_at=datetime.datetime.now(), + updated_at=datetime.datetime.now(), + provider="openai", + tenant_id="tenant", + plugin_unique_identifier="acme/openai/openai", + plugin_id="acme/openai", + declaration=ProviderEntity( + provider="openai", + label=I18nObject(en_US="Acme OpenAI"), + supported_model_types=[], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + ), + ), + PluginModelProviderEntity( + id=uuid.uuid4().hex, + created_at=datetime.datetime.now(), + updated_at=datetime.datetime.now(), + provider="openai", + tenant_id="tenant", + plugin_unique_identifier="langgenius/openai/openai", + plugin_id="langgenius/openai", + declaration=ProviderEntity( + provider="openai", + label=I18nObject(en_US="OpenAI"), + supported_model_types=[], + 
configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + ), + ), + ] + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) + + providers = runtime.fetch_model_providers() + + provider_aliases = {provider.provider: provider.provider_name for provider in providers} + assert provider_aliases["acme/openai/openai"] == "" + assert provider_aliases["langgenius/openai/openai"] == "openai" + + def test_fetch_model_providers_keeps_google_alias_on_canonical_gemini_provider(self) -> None: + client = Mock(spec=PluginModelClient) + client.fetch_model_providers.return_value = [ + PluginModelProviderEntity( + id=uuid.uuid4().hex, + created_at=datetime.datetime.now(), + updated_at=datetime.datetime.now(), + provider="google", + tenant_id="tenant", + plugin_unique_identifier="langgenius/gemini/google", + plugin_id="langgenius/gemini", + declaration=ProviderEntity( + provider="google", + label=I18nObject(en_US="Google"), + supported_model_types=[], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + ), + ) + ] + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) + + providers = runtime.fetch_model_providers() + + assert providers[0].provider == "langgenius/gemini/google" + assert providers[0].provider_name == "google" + + def test_validate_provider_credentials_resolves_plugin_fields(self) -> None: + client = Mock(spec=PluginModelClient) + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) + + runtime.validate_provider_credentials( + provider="langgenius/openai/openai", + credentials={"api_key": "secret"}, + ) + + client.validate_provider_credentials.assert_called_once_with( + tenant_id="tenant", + user_id="user", + plugin_id="langgenius/openai", + provider="openai", + credentials={"api_key": "secret"}, + ) + + def test_invoke_llm_resolves_plugin_fields(self) -> None: + client = Mock(spec=PluginModelClient) + client.invoke_llm.return_value = sentinel.result + runtime = 
PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) + + result = runtime.invoke_llm( + provider="langgenius/openai/openai", + model="gpt-4o-mini", + credentials={"api_key": "secret"}, + model_parameters={"temperature": 0.3}, + prompt_messages=[], + tools=None, + stop=None, + stream=False, + ) + + assert result is sentinel.result + client.invoke_llm.assert_called_once_with( + tenant_id="tenant", + user_id="user", + plugin_id="langgenius/openai", + provider="openai", + model="gpt-4o-mini", + credentials={"api_key": "secret"}, + model_parameters={"temperature": 0.3}, + prompt_messages=[], + tools=None, + stop=None, + stream=False, + ) + + def test_invoke_llm_rejects_per_call_user_override(self) -> None: + client = Mock(spec=PluginModelClient) + client.invoke_llm.return_value = sentinel.result + runtime = PluginModelRuntime(tenant_id="tenant", user_id="bound-user", client=client) + + with pytest.raises(TypeError, match="unexpected keyword argument 'user_id'"): + runtime.invoke_llm( # type: ignore[call-arg] + provider="langgenius/openai/openai", + model="gpt-4o-mini", + credentials={"api_key": "secret"}, + model_parameters={"temperature": 0.3}, + prompt_messages=[], + tools=None, + stop=None, + stream=False, + user_id="request-user", + ) + + client.invoke_llm.assert_not_called() + + def test_invoke_tts_uses_bound_runtime_user_when_runtime_is_unbound(self) -> None: + client = Mock(spec=PluginModelClient) + client.invoke_tts.return_value = iter([b"chunk"]) + runtime = PluginModelRuntime(tenant_id="tenant", user_id=None, client=client) + + result = runtime.invoke_tts( + provider="langgenius/openai/openai", + model="tts-1", + credentials={"api_key": "secret"}, + content_text="hello", + voice="alloy", + ) + + assert list(result) == [b"chunk"] + client.invoke_tts.assert_called_once_with( + tenant_id="tenant", + user_id=None, + plugin_id="langgenius/openai", + provider="openai", + model="tts-1", + credentials={"api_key": "secret"}, + content_text="hello", + 
voice="alloy", + ) + + def test_fetch_model_providers_uses_bound_runtime_cache(self) -> None: + client = Mock(spec=PluginModelClient) + client.fetch_model_providers.return_value = [] + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) + + runtime.fetch_model_providers() + runtime.fetch_model_providers() + + client.fetch_model_providers.assert_called_once_with("tenant") + + +def test_create_plugin_model_runtime_without_user_context() -> None: + runtime = create_plugin_model_runtime(tenant_id="tenant") + + assert runtime.user_id is None + + +def test_plugin_model_runtime_requires_client() -> None: + with pytest.raises(ValueError, match="client is required"): + PluginModelRuntime(tenant_id="tenant", user_id="user", client=None) # type: ignore[arg-type] + + +def test_get_model_schema_uses_cached_schema_without_hitting_client(monkeypatch: pytest.MonkeyPatch) -> None: + client = Mock(spec=PluginModelClient) + schema = _build_model_schema() + monkeypatch.setattr( + model_runtime_module, + "redis_client", + SimpleNamespace( + get=Mock(return_value=schema.model_dump_json()), + delete=Mock(), + setex=Mock(), + ), + ) + + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) + result = runtime.get_model_schema( + provider="langgenius/openai/openai", + model_type=ModelType.LLM, + model="gpt-4o-mini", + credentials={"api_key": "secret"}, + ) + + assert result == schema + client.get_model_schema.assert_not_called() + + +def test_get_model_schema_deletes_invalid_cache_and_refetches(monkeypatch: pytest.MonkeyPatch) -> None: + client = Mock(spec=PluginModelClient) + schema = _build_model_schema() + delete = Mock() + setex = Mock() + monkeypatch.setattr( + model_runtime_module, + "redis_client", + SimpleNamespace( + get=Mock(return_value="not-json"), + delete=delete, + setex=setex, + ), + ) + monkeypatch.setattr(model_runtime_module.dify_config, "PLUGIN_MODEL_SCHEMA_CACHE_TTL", 300) + client.get_model_schema.return_value = 
schema + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) + + result = runtime.get_model_schema( + provider="langgenius/openai/openai", + model_type=ModelType.LLM, + model="gpt-4o-mini", + credentials={"api_key": "secret"}, + ) + + assert result == schema + delete.assert_called_once() + client.get_model_schema.assert_called_once_with( + tenant_id="tenant", + user_id="user", + plugin_id="langgenius/openai", + provider="openai", + model_type=ModelType.LLM.value, + model="gpt-4o-mini", + credentials={"api_key": "secret"}, + ) + setex.assert_called_once() + + +def test_get_llm_num_tokens_returns_zero_when_plugin_counting_is_disabled(monkeypatch: pytest.MonkeyPatch) -> None: + client = Mock(spec=PluginModelClient) + monkeypatch.setattr(model_runtime_module.dify_config, "PLUGIN_BASED_TOKEN_COUNTING_ENABLED", False) + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) + + assert ( + runtime.get_llm_num_tokens( + provider="langgenius/openai/openai", + model_type=ModelType.LLM, + model="gpt-4o-mini", + credentials={"api_key": "secret"}, + prompt_messages=[], + tools=None, + ) + == 0 + ) + client.get_llm_num_tokens.assert_not_called() + + +def test_get_provider_icon_reads_requested_variant_and_detects_svg_mime(monkeypatch: pytest.MonkeyPatch) -> None: + client = Mock(spec=PluginModelClient) + client.fetch_model_providers.return_value = [ + PluginModelProviderEntity( + id=uuid.uuid4().hex, + created_at=datetime.datetime.now(), + updated_at=datetime.datetime.now(), + provider="openai", + tenant_id="tenant", + plugin_unique_identifier="langgenius/openai/openai", + plugin_id="langgenius/openai", + declaration=ProviderEntity( + provider="openai", + label=I18nObject(en_US="OpenAI"), + icon_small=I18nObject(en_US="logo.svg"), + icon_small_dark=I18nObject(en_US="logo-dark.png"), + supported_model_types=[], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + ), + ) + ] + fetch_asset = Mock(return_value=b"") + 
monkeypatch.setattr(model_runtime_module.PluginAssetManager, "fetch_asset", fetch_asset) + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) + + icon_bytes, mime_type = runtime.get_provider_icon( + provider="langgenius/openai/openai", + icon_type="icon_small", + lang="en_US", + ) + + assert icon_bytes == b"" + assert mime_type == "image/svg+xml" + fetch_asset.assert_called_once_with(tenant_id="tenant", id="logo.svg") + + +def test_get_provider_icon_rejects_unsupported_types_and_missing_variants() -> None: + client = Mock(spec=PluginModelClient) + client.fetch_model_providers.return_value = [ + PluginModelProviderEntity( + id=uuid.uuid4().hex, + created_at=datetime.datetime.now(), + updated_at=datetime.datetime.now(), + provider="openai", + tenant_id="tenant", + plugin_unique_identifier="langgenius/openai/openai", + plugin_id="langgenius/openai", + declaration=ProviderEntity( + provider="openai", + label=I18nObject(en_US="OpenAI"), + supported_model_types=[], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + ), + ) + ] + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) + + with pytest.raises(ValueError, match="does not have small dark icon"): + runtime.get_provider_icon( + provider="langgenius/openai/openai", + icon_type="icon_small_dark", + lang="en_US", + ) + + with pytest.raises(ValueError, match="Unsupported icon type"): + runtime.get_provider_icon( + provider="langgenius/openai/openai", + icon_type="icon_large", + lang="en_US", + ) + + +def test_get_schema_cache_key_is_stable_across_credential_order() -> None: + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=Mock(spec=PluginModelClient)) + + first = runtime._get_schema_cache_key( + provider="langgenius/openai/openai", + model_type=ModelType.LLM, + model="gpt-4o-mini", + credentials={"b": "2", "a": "1"}, + ) + second = runtime._get_schema_cache_key( + provider="langgenius/openai/openai", + model_type=ModelType.LLM, 
+ model="gpt-4o-mini", + credentials={"a": "1", "b": "2"}, + ) + + assert first == second + + +def test_get_schema_cache_key_separates_distinct_user_scopes() -> None: + first_runtime = PluginModelRuntime(tenant_id="tenant", user_id="user-a", client=Mock(spec=PluginModelClient)) + second_runtime = PluginModelRuntime(tenant_id="tenant", user_id="user-b", client=Mock(spec=PluginModelClient)) + + first = first_runtime._get_schema_cache_key( + provider="langgenius/openai/openai", + model_type=ModelType.LLM, + model="gpt-4o-mini", + credentials={"a": "1"}, + ) + second = second_runtime._get_schema_cache_key( + provider="langgenius/openai/openai", + model_type=ModelType.LLM, + model="gpt-4o-mini", + credentials={"a": "1"}, + ) + + assert first != second + + +def test_get_schema_cache_key_separates_tenant_scope_from_user_scope() -> None: + tenant_runtime = PluginModelRuntime(tenant_id="tenant", user_id=None, client=Mock(spec=PluginModelClient)) + user_runtime = PluginModelRuntime(tenant_id="tenant", user_id="user-a", client=Mock(spec=PluginModelClient)) + + tenant_key = tenant_runtime._get_schema_cache_key( + provider="langgenius/openai/openai", + model_type=ModelType.LLM, + model="gpt-4o-mini", + credentials={"a": "1"}, + ) + user_key = user_runtime._get_schema_cache_key( + provider="langgenius/openai/openai", + model_type=ModelType.LLM, + model="gpt-4o-mini", + credentials={"a": "1"}, + ) + + assert tenant_key != user_key + assert f":{TENANT_SCOPE_SCHEMA_CACHE_USER_ID}" in tenant_key + + +def test_get_schema_cache_key_separates_tenant_scope_from_empty_string_user_scope() -> None: + tenant_runtime = PluginModelRuntime(tenant_id="tenant", user_id=None, client=Mock(spec=PluginModelClient)) + empty_user_runtime = PluginModelRuntime(tenant_id="tenant", user_id="", client=Mock(spec=PluginModelClient)) + + tenant_key = tenant_runtime._get_schema_cache_key( + provider="langgenius/openai/openai", + model_type=ModelType.LLM, + model="gpt-4o-mini", + credentials={}, + ) + 
empty_user_key = empty_user_runtime._get_schema_cache_key( + provider="langgenius/openai/openai", + model_type=ModelType.LLM, + model="gpt-4o-mini", + credentials={}, + ) + + assert tenant_key != empty_user_key + assert empty_user_key.endswith(":") + assert TENANT_SCOPE_SCHEMA_CACHE_USER_ID not in empty_user_key + + +def test_get_provider_schema_supports_short_alias_and_rejects_invalid_provider() -> None: + client = Mock(spec=PluginModelClient) + client.fetch_model_providers.return_value = [ + PluginModelProviderEntity( + id=uuid.uuid4().hex, + created_at=datetime.datetime.now(), + updated_at=datetime.datetime.now(), + provider="openai", + tenant_id="tenant", + plugin_unique_identifier="langgenius/openai/openai", + plugin_id="langgenius/openai", + declaration=ProviderEntity( + provider="openai", + label=I18nObject(en_US="OpenAI"), + supported_model_types=[], + configurate_methods=[ConfigurateMethod.PREDEFINED_MODEL], + ), + ) + ] + runtime = PluginModelRuntime(tenant_id="tenant", user_id="user", client=client) + + assert runtime._get_provider_schema("openai").provider == "langgenius/openai/openai" + + with pytest.raises(ValueError, match="Invalid provider"): + runtime._get_provider_schema("missing") diff --git a/api/tests/unit_tests/core/plugin/test_plugin_entities.py b/api/tests/unit_tests/core/plugin/test_plugin_entities.py index b0b64a601b..a812b01c5b 100644 --- a/api/tests/unit_tests/core/plugin/test_plugin_entities.py +++ b/api/tests/unit_tests/core/plugin/test_plugin_entities.py @@ -4,6 +4,12 @@ from enum import StrEnum import pytest from flask import Response +from graphon.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + SystemPromptMessage, + ToolPromptMessage, + UserPromptMessage, +) from pydantic import ValidationError from core.plugin.entities.endpoint import EndpointEntityWithInstance @@ -25,12 +31,6 @@ from core.plugin.entities.request import ( ) from core.plugin.utils.http_parser import serialize_response from 
core.tools.entities.common_entities import I18nObject -from dify_graph.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - SystemPromptMessage, - ToolPromptMessage, - UserPromptMessage, -) class TestEndpointEntity: diff --git a/api/tests/unit_tests/core/plugin/test_plugin_runtime.py b/api/tests/unit_tests/core/plugin/test_plugin_runtime.py index 4f038d4a5b..3063ca0197 100644 --- a/api/tests/unit_tests/core/plugin/test_plugin_runtime.py +++ b/api/tests/unit_tests/core/plugin/test_plugin_runtime.py @@ -17,6 +17,14 @@ from unittest.mock import MagicMock, patch import httpx import pytest +from graphon.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) +from graphon.model_runtime.errors.validate import CredentialsValidateFailedError from pydantic import BaseModel from core.plugin.entities.plugin_daemon import ( @@ -26,6 +34,7 @@ from core.plugin.entities.plugin_daemon import ( from core.plugin.impl.base import BasePluginClient from core.plugin.impl.exc import ( PluginDaemonBadRequestError, + PluginDaemonClientSideError, PluginDaemonInternalServerError, PluginDaemonNotFoundError, PluginDaemonUnauthorizedError, @@ -36,14 +45,6 @@ from core.plugin.impl.exc import ( ) from core.plugin.impl.plugin import PluginInstaller from core.plugin.impl.tool import PluginToolManager -from dify_graph.model_runtime.errors.invoke import ( - InvokeAuthorizationError, - InvokeBadRequestError, - InvokeConnectionError, - InvokeRateLimitError, - InvokeServerUnavailableError, -) -from dify_graph.model_runtime.errors.validate import CredentialsValidateFailedError class TestPluginRuntimeExecution: @@ -557,7 +558,7 @@ class TestPluginRuntimeErrorHandling: with patch("httpx.request", return_value=mock_response, autospec=True): # Act & Assert - with pytest.raises(httpx.HTTPStatusError): + with pytest.raises(PluginDaemonInternalServerError): 
plugin_client._request_with_plugin_daemon_response("GET", "plugin/test-tenant/test", bool) def test_empty_data_response_error(self, plugin_client, mock_config): @@ -1808,8 +1809,8 @@ class TestPluginInstallerAdvanced: mock_response.raise_for_status = raise_for_status with patch("httpx.request", return_value=mock_response, autospec=True): - # Act & Assert - Should raise HTTPStatusError for 404 - with pytest.raises(httpx.HTTPStatusError): + # Act & Assert - Should raise PluginDaemonClientSideError for 404 + with pytest.raises(PluginDaemonClientSideError): installer.fetch_plugin_readme("test-tenant", "test-org/test-plugin", "en") def test_list_plugins_with_pagination(self, installer, mock_config): diff --git a/api/tests/unit_tests/core/plugin/utils/test_chunk_merger.py b/api/tests/unit_tests/core/plugin/utils/test_chunk_merger.py index c7e94aa4cf..90730dff5a 100644 --- a/api/tests/unit_tests/core/plugin/utils/test_chunk_merger.py +++ b/api/tests/unit_tests/core/plugin/utils/test_chunk_merger.py @@ -1,13 +1,12 @@ from collections.abc import Generator import pytest +from graphon.file import File, FileTransferMethod, FileType from core.agent.entities import AgentInvokeMessage from core.plugin.utils.chunk_merger import FileChunk, merge_blob_chunks from core.plugin.utils.converter import convert_parameters_to_plugin_format from core.tools.entities.tool_entities import ToolInvokeMessage, ToolParameter, ToolSelector -from dify_graph.file.enums import FileTransferMethod, FileType -from dify_graph.file.models import File class TestChunkMerger: diff --git a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py index 3d08525aba..2b280dd674 100644 --- a/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_advanced_prompt_transform.py @@ -2,15 +2,8 @@ from typing import cast from unittest.mock import MagicMock, patch import pytest - -from 
configs import dify_config -from core.app.app_config.entities import ModelConfigEntity -from core.memory.token_buffer_memory import TokenBufferMemory -from core.prompt.advanced_prompt_transform import AdvancedPromptTransform -from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig -from core.prompt.utils.prompt_template_parser import PromptTemplateParser -from dify_graph.file import File, FileTransferMethod, FileType -from dify_graph.model_runtime.entities.message_entities import ( +from graphon.file import File, FileTransferMethod, FileType +from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, ImagePromptMessageContent, PromptMessageRole, @@ -18,6 +11,13 @@ from dify_graph.model_runtime.entities.message_entities import ( TextPromptMessageContent, UserPromptMessage, ) + +from configs import dify_config +from core.app.app_config.entities import ModelConfigEntity +from core.memory.token_buffer_memory import TokenBufferMemory +from core.prompt.advanced_prompt_transform import AdvancedPromptTransform +from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig +from core.prompt.utils.prompt_template_parser import PromptTemplateParser from models.model import Conversation @@ -145,7 +145,7 @@ def test__get_chat_model_prompt_messages_with_files_no_memory(get_chat_model_arg prompt_transform = AdvancedPromptTransform() prompt_transform._calculate_rest_token = MagicMock(return_value=2000) - with patch("dify_graph.file.file_manager.to_prompt_message_content", autospec=True) as mock_get_encoded_string: + with patch("graphon.file.file_manager.to_prompt_message_content", autospec=True) as mock_get_encoded_string: mock_get_encoded_string.return_value = ImagePromptMessageContent( url=str(files[0].remote_url), format="jpg", mime_type="image/jpg" ) diff --git 
a/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py index 634703740c..4a54649b28 100644 --- a/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_agent_history_prompt_transform.py @@ -1,18 +1,19 @@ from unittest.mock import MagicMock +from graphon.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + SystemPromptMessage, + ToolPromptMessage, + UserPromptMessage, +) +from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel + from core.app.entities.app_invoke_entities import ( ModelConfigWithCredentialsEntity, ) from core.entities.provider_configuration import ProviderModelBundle from core.memory.token_buffer_memory import TokenBufferMemory from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform -from dify_graph.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - SystemPromptMessage, - ToolPromptMessage, - UserPromptMessage, -) -from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel from models.model import Conversation diff --git a/api/tests/unit_tests/core/prompt/test_prompt_message.py b/api/tests/unit_tests/core/prompt/test_prompt_message.py index 9fc300348a..a4b3960b0a 100644 --- a/api/tests/unit_tests/core/prompt/test_prompt_message.py +++ b/api/tests/unit_tests/core/prompt/test_prompt_message.py @@ -1,6 +1,4 @@ -from core.prompt.simple_prompt_transform import ModelMode -from core.prompt.utils.prompt_message_util import PromptMessageUtil -from dify_graph.model_runtime.entities.message_entities import ( +from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, AudioPromptMessageContent, ImagePromptMessageContent, @@ -9,6 +7,9 @@ from dify_graph.model_runtime.entities.message_entities import ( UserPromptMessage, ) +from 
core.prompt.simple_prompt_transform import ModelMode +from core.prompt.utils.prompt_message_util import PromptMessageUtil + def test_build_prompt_message_with_prompt_message_contents(): prompt = UserPromptMessage(content=[TextPromptMessageContent(data="Hello, World!")]) diff --git a/api/tests/unit_tests/core/prompt/test_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_prompt_transform.py index d379e3067a..e35ce2c48a 100644 --- a/api/tests/unit_tests/core/prompt/test_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_prompt_transform.py @@ -2,16 +2,16 @@ from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest +from graphon.model_runtime.entities.model_entities import ModelPropertyKey from core.prompt.prompt_transform import PromptTransform -from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey # from core.app.app_config.entities import ModelConfigEntity # from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle -# from dify_graph.model_runtime.entities.message_entities import UserPromptMessage -# from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelPropertyKey, ParameterRule -# from dify_graph.model_runtime.entities.provider_entities import ProviderEntity -# from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel +# from graphon.model_runtime.entities.message_entities import UserPromptMessage +# from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelPropertyKey, ParameterRule +# from graphon.model_runtime.entities.provider_entities import ProviderEntity +# from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel # from core.prompt.prompt_transform import PromptTransform diff --git a/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py b/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py index 
e6d28224d7..3f188cfbb4 100644 --- a/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py +++ b/api/tests/unit_tests/core/prompt/test_simple_prompt_transform.py @@ -2,6 +2,12 @@ from types import SimpleNamespace from unittest.mock import MagicMock, patch import pytest +from graphon.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + ImagePromptMessageContent, + TextPromptMessageContent, + UserPromptMessage, +) from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity from core.memory.token_buffer_memory import TokenBufferMemory @@ -18,12 +24,6 @@ from core.prompt.prompt_templates.advanced_prompt_templates import ( CONTEXT, ) from core.prompt.simple_prompt_transform import SimplePromptTransform -from dify_graph.model_runtime.entities.message_entities import ( - AssistantPromptMessage, - ImagePromptMessageContent, - TextPromptMessageContent, - UserPromptMessage, -) from models.model import AppMode, Conversation diff --git a/api/tests/unit_tests/core/rag/data_post_processor/test_data_post_processor.py b/api/tests/unit_tests/core/rag/data_post_processor/test_data_post_processor.py index 538457ccc8..006b4e7345 100644 --- a/api/tests/unit_tests/core/rag/data_post_processor/test_data_post_processor.py +++ b/api/tests/unit_tests/core/rag/data_post_processor/test_data_post_processor.py @@ -1,14 +1,13 @@ from unittest.mock import MagicMock, patch -import pytest +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.errors.invoke import InvokeAuthorizationError from core.rag.data_post_processor.data_post_processor import DataPostProcessor from core.rag.data_post_processor.reorder import ReorderRunner from core.rag.index_processor.constant.query_type import QueryType from core.rag.models.document import Document from core.rag.rerank.rerank_type import RerankMode -from dify_graph.model_runtime.entities.model_entities import ModelType -from 
dify_graph.model_runtime.errors.invoke import InvokeAuthorizationError def _doc(content: str) -> Document: @@ -56,7 +55,6 @@ class TestDataPostProcessor: documents=original_documents, score_threshold=0.3, top_n=2, - user="user-1", query_type=QueryType.IMAGE_QUERY, ) @@ -65,7 +63,6 @@ class TestDataPostProcessor: original_documents, 0.3, 2, - "user-1", QueryType.IMAGE_QUERY, ) processor.reorder_runner.run.assert_called_once_with(reranked_documents) @@ -176,25 +173,24 @@ class TestDataPostProcessor: processor = DataPostProcessor.__new__(DataPostProcessor) assert processor._get_rerank_model_instance("tenant-1", None) is None - def test_get_rerank_model_instance_raises_key_error_for_incomplete_config(self): + def test_get_rerank_model_instance_returns_none_for_incomplete_config(self): processor = DataPostProcessor.__new__(DataPostProcessor) - with patch("core.rag.data_post_processor.data_post_processor.ModelManager") as manager_cls: - manager_instance = manager_cls.return_value - with pytest.raises(KeyError, match="reranking_model_name"): - processor._get_rerank_model_instance( - tenant_id="tenant-1", - reranking_model={"reranking_provider_name": "provider-x"}, - ) + with patch("core.rag.data_post_processor.data_post_processor.ModelManager.for_tenant") as for_tenant_mock: + result = processor._get_rerank_model_instance( + tenant_id="tenant-1", + reranking_model={"reranking_provider_name": "provider-x"}, + ) - manager_instance.get_model_instance.assert_not_called() + assert result is None + for_tenant_mock.assert_called_once_with(tenant_id="tenant-1") def test_get_rerank_model_instance_success(self): processor = DataPostProcessor.__new__(DataPostProcessor) model_instance = object() - with patch("core.rag.data_post_processor.data_post_processor.ModelManager") as manager_cls: - manager_instance = manager_cls.return_value + with patch("core.rag.data_post_processor.data_post_processor.ModelManager.for_tenant") as for_tenant_mock: + manager_instance = 
for_tenant_mock.return_value manager_instance.get_model_instance.return_value = model_instance result = processor._get_rerank_model_instance( @@ -206,6 +202,7 @@ class TestDataPostProcessor: ) assert result is model_instance + for_tenant_mock.assert_called_once_with(tenant_id="tenant-1") manager_instance.get_model_instance.assert_called_once_with( tenant_id="tenant-1", provider="provider-x", @@ -216,8 +213,8 @@ class TestDataPostProcessor: def test_get_rerank_model_instance_handles_authorization_error(self): processor = DataPostProcessor.__new__(DataPostProcessor) - with patch("core.rag.data_post_processor.data_post_processor.ModelManager") as manager_cls: - manager_instance = manager_cls.return_value + with patch("core.rag.data_post_processor.data_post_processor.ModelManager.for_tenant") as for_tenant_mock: + manager_instance = for_tenant_mock.return_value manager_instance.get_model_instance.side_effect = InvokeAuthorizationError("not authorized") result = processor._get_rerank_model_instance( @@ -229,6 +226,7 @@ class TestDataPostProcessor: ) assert result is None + for_tenant_mock.assert_called_once_with(tenant_id="tenant-1") class TestReorderRunner: diff --git a/api/tests/unit_tests/core/rag/datasource/test_datasource_retrieval.py b/api/tests/unit_tests/core/rag/datasource/test_datasource_retrieval.py index 8c1e4e478b..63de4b8af2 100644 --- a/api/tests/unit_tests/core/rag/datasource/test_datasource_retrieval.py +++ b/api/tests/unit_tests/core/rag/datasource/test_datasource_retrieval.py @@ -399,7 +399,7 @@ class TestRetrievalServiceInternals: assert exceptions == [] vector_instance.search_by_file.assert_not_called() - @patch("core.rag.datasource.retrieval_service.ModelManager") + @patch("core.rag.datasource.retrieval_service.ModelManager.for_tenant") @patch("core.rag.datasource.retrieval_service.DataPostProcessor") @patch("core.rag.datasource.retrieval_service.Vector") @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") @@ -451,9 
+451,10 @@ class TestRetrievalServiceInternals: assert all_documents == reranked_docs assert exceptions == [] processor_instance.invoke.assert_called_once() + mock_model_manager_class.assert_called_once_with(tenant_id=internal_dataset.tenant_id) model_manager.check_model_support_vision.assert_called_once() - @patch("core.rag.datasource.retrieval_service.ModelManager") + @patch("core.rag.datasource.retrieval_service.ModelManager.for_tenant") @patch("core.rag.datasource.retrieval_service.DataPostProcessor") @patch("core.rag.datasource.retrieval_service.Vector") @patch("core.rag.datasource.retrieval_service.RetrievalService._get_dataset") @@ -503,6 +504,7 @@ class TestRetrievalServiceInternals: assert all_documents == original_docs assert exceptions == [] + mock_model_manager_class.assert_called_once_with(tenant_id=internal_dataset.tenant_id) processor_instance.invoke.assert_not_called() @patch("core.rag.datasource.retrieval_service.DataPostProcessor") diff --git a/api/tests/unit_tests/core/rag/datasource/vdb/test_vector_factory.py b/api/tests/unit_tests/core/rag/datasource/vdb/test_vector_factory.py index dd536af759..54ad6d330b 100644 --- a/api/tests/unit_tests/core/rag/datasource/vdb/test_vector_factory.py +++ b/api/tests/unit_tests/core/rag/datasource/vdb/test_vector_factory.py @@ -384,7 +384,8 @@ def test_get_embeddings_builds_cache_embedding(vector_factory_module, monkeypatc model_manager = MagicMock() model_manager.get_model_instance.return_value = "model-instance" - monkeypatch.setattr(vector_factory_module, "ModelManager", MagicMock(return_value=model_manager)) + for_tenant_mock = MagicMock(return_value=model_manager) + monkeypatch.setattr(vector_factory_module.ModelManager, "for_tenant", for_tenant_mock) monkeypatch.setattr(vector_factory_module, "CacheEmbedding", MagicMock(return_value="cached-embedding")) vector = vector_factory_module.Vector.__new__(vector_factory_module.Vector) @@ -397,6 +398,7 @@ def 
test_get_embeddings_builds_cache_embedding(vector_factory_module, monkeypatc result = vector._get_embeddings() assert result == "cached-embedding" + for_tenant_mock.assert_called_once_with(tenant_id="tenant-1") model_manager.get_model_instance.assert_called_once_with( tenant_id="tenant-1", provider="openai", diff --git a/api/tests/unit_tests/core/rag/docstore/test_dataset_docstore.py b/api/tests/unit_tests/core/rag/docstore/test_dataset_docstore.py index 13285cdad0..3ba0628fe2 100644 --- a/api/tests/unit_tests/core/rag/docstore/test_dataset_docstore.py +++ b/api/tests/unit_tests/core/rag/docstore/test_dataset_docstore.py @@ -163,7 +163,7 @@ class TestDatasetDocumentStoreAddDocuments: with ( patch("core.rag.docstore.dataset_docstore.db") as mock_db, - patch("core.rag.docstore.dataset_docstore.ModelManager") as mock_manager_class, + patch("core.rag.docstore.dataset_docstore.ModelManager.for_tenant") as mock_manager_class, ): mock_session = MagicMock() mock_db.session = mock_session diff --git a/api/tests/unit_tests/core/rag/embedding/test_cached_embedding.py b/api/tests/unit_tests/core/rag/embedding/test_cached_embedding.py index a0db25174d..6fd44be4d4 100644 --- a/api/tests/unit_tests/core/rag/embedding/test_cached_embedding.py +++ b/api/tests/unit_tests/core/rag/embedding/test_cached_embedding.py @@ -12,11 +12,11 @@ from unittest.mock import Mock, patch import numpy as np import pytest +from graphon.model_runtime.entities.model_entities import ModelPropertyKey +from graphon.model_runtime.entities.text_embedding_entities import EmbeddingResult, EmbeddingUsage from sqlalchemy.exc import IntegrityError from core.rag.embedding.cached_embedding import CacheEmbedding -from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey -from dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingResult, EmbeddingUsage from models.dataset import Embedding @@ -28,6 +28,7 @@ class TestCacheEmbeddingMultimodalDocuments: """Create a mock 
ModelInstance for testing.""" model_instance = Mock() model_instance.model = "vision-embedding-model" + model_instance.model_name = "vision-embedding-model" model_instance.provider = "openai" model_instance.credentials = {"api_key": "test-key"} @@ -64,7 +65,7 @@ class TestCacheEmbeddingMultimodalDocuments: def test_embed_single_multimodal_document_cache_miss(self, mock_model_instance, sample_multimodal_result): """Test embedding a single multimodal document when cache is empty.""" - cache_embedding = CacheEmbedding(mock_model_instance, user="test-user") + cache_embedding = CacheEmbedding(mock_model_instance) documents = [{"file_id": "file123", "content": "test content"}] with patch("core.rag.embedding.cached_embedding.db.session") as mock_session: @@ -316,13 +317,14 @@ class TestCacheEmbeddingMultimodalQuery: """Create a mock ModelInstance for testing.""" model_instance = Mock() model_instance.model = "vision-embedding-model" + model_instance.model_name = "vision-embedding-model" model_instance.provider = "openai" model_instance.credentials = {"api_key": "test-key"} return model_instance def test_embed_multimodal_query_cache_miss(self, mock_model_instance): """Test embedding multimodal query when Redis cache is empty.""" - cache_embedding = CacheEmbedding(mock_model_instance, user="test-user") + cache_embedding = CacheEmbedding(mock_model_instance) document = {"file_id": "file123"} vector = np.random.randn(1536) @@ -467,6 +469,7 @@ class TestCacheEmbeddingQueryErrors: """Create a mock ModelInstance for testing.""" model_instance = Mock() model_instance.model = "text-embedding-ada-002" + model_instance.model_name = "text-embedding-ada-002" model_instance.provider = "openai" model_instance.credentials = {"api_key": "test-key"} return model_instance @@ -532,24 +535,13 @@ class TestCacheEmbeddingQueryErrors: class TestCacheEmbeddingInitialization: """Test suite for CacheEmbedding initialization.""" - def test_initialization_with_user(self): - """Test CacheEmbedding 
initialization with user parameter.""" - model_instance = Mock() - model_instance.model = "test-model" - model_instance.provider = "test-provider" - - cache_embedding = CacheEmbedding(model_instance, user="test-user") - - assert cache_embedding._model_instance == model_instance - assert cache_embedding._user == "test-user" - - def test_initialization_without_user(self): - """Test CacheEmbedding initialization without user parameter.""" + def test_initialization_sets_model_instance(self): + """Test CacheEmbedding initialization stores the provided model instance.""" model_instance = Mock() model_instance.model = "test-model" + model_instance.model_name = "test-model" model_instance.provider = "test-provider" cache_embedding = CacheEmbedding(model_instance) assert cache_embedding._model_instance == model_instance - assert cache_embedding._user is None diff --git a/api/tests/unit_tests/core/rag/embedding/test_embedding_service.py b/api/tests/unit_tests/core/rag/embedding/test_embedding_service.py index 6e71f0c61f..d7ba944e58 100644 --- a/api/tests/unit_tests/core/rag/embedding/test_embedding_service.py +++ b/api/tests/unit_tests/core/rag/embedding/test_embedding_service.py @@ -49,17 +49,17 @@ from unittest.mock import Mock, patch import numpy as np import pytest -from sqlalchemy.exc import IntegrityError - -from core.entities.embedding_type import EmbeddingInputType -from core.rag.embedding.cached_embedding import CacheEmbedding -from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey -from dify_graph.model_runtime.entities.text_embedding_entities import EmbeddingResult, EmbeddingUsage -from dify_graph.model_runtime.errors.invoke import ( +from graphon.model_runtime.entities.model_entities import ModelPropertyKey +from graphon.model_runtime.entities.text_embedding_entities import EmbeddingResult, EmbeddingUsage +from graphon.model_runtime.errors.invoke import ( InvokeAuthorizationError, InvokeConnectionError, InvokeRateLimitError, ) +from 
sqlalchemy.exc import IntegrityError + +from core.entities.embedding_type import EmbeddingInputType +from core.rag.embedding.cached_embedding import CacheEmbedding from models.dataset import Embedding @@ -134,7 +134,7 @@ class TestCacheEmbeddingDocuments: - Correct return value """ # Arrange - cache_embedding = CacheEmbedding(mock_model_instance, user="test-user") + cache_embedding = CacheEmbedding(mock_model_instance) texts = ["Python is a programming language"] # Mock database query to return no cached embedding (cache miss) @@ -156,7 +156,6 @@ class TestCacheEmbeddingDocuments: # Verify model was invoked with correct parameters mock_model_instance.invoke_text_embedding.assert_called_once_with( texts=texts, - user="test-user", input_type=EmbeddingInputType.DOCUMENT, ) @@ -612,7 +611,7 @@ class TestCacheEmbeddingQuery: - Correct return value """ # Arrange - cache_embedding = CacheEmbedding(mock_model_instance, user="test-user") + cache_embedding = CacheEmbedding(mock_model_instance) query = "What is Python?" # Create embedding result @@ -651,7 +650,6 @@ class TestCacheEmbeddingQuery: # Verify model was invoked with QUERY input type mock_model_instance.invoke_text_embedding.assert_called_once_with( texts=[query], - user="test-user", input_type=EmbeddingInputType.QUERY, ) @@ -1568,25 +1566,16 @@ class TestEmbeddingEdgeCases: norm = np.linalg.norm(emb) assert abs(norm - 1.0) < 0.01 - def test_embed_query_with_user_context(self, mock_model_instance): - """Test query embedding with user context parameter. + def test_embed_query_uses_bound_model_instance(self, mock_model_instance): + """Test query embedding using the provided model instance. Verifies: - - User parameter is passed correctly to model - - User context is used for tracking/logging - - Embedding generation works with user context - - Context: - -------- - The user parameter is important for: - 1. Usage tracking per user - 2. Rate limiting per user - 3. Audit logging - 4. 
Personalization (in some models) + - Embedding generation works with the injected model instance + - Query input type is preserved + - No extra binding step is required at call time """ # Arrange - user_id = "user-12345" - cache_embedding = CacheEmbedding(mock_model_instance, user=user_id) + cache_embedding = CacheEmbedding(mock_model_instance) query = "What is machine learning?" # Create embedding @@ -1620,24 +1609,20 @@ class TestEmbeddingEdgeCases: assert isinstance(result, list) assert len(result) == 1536 - # Verify user parameter was passed to model mock_model_instance.invoke_text_embedding.assert_called_once_with( texts=[query], - user=user_id, input_type=EmbeddingInputType.QUERY, ) - def test_embed_documents_with_user_context(self, mock_model_instance): - """Test document embedding with user context parameter. + def test_embed_documents_uses_bound_model_instance(self, mock_model_instance): + """Test document embedding using the provided model instance. Verifies: - - User parameter is passed correctly for document embeddings - - Batch processing maintains user context - - User tracking works across batches + - Batch processing uses the injected model instance + - Document input type is preserved """ # Arrange - user_id = "user-67890" - cache_embedding = CacheEmbedding(mock_model_instance, user=user_id) + cache_embedding = CacheEmbedding(mock_model_instance) texts = ["Document 1", "Document 2"] # Create embeddings @@ -1673,10 +1658,8 @@ class TestEmbeddingEdgeCases: # Assert assert len(result) == 2 - # Verify user parameter was passed mock_model_instance.invoke_text_embedding.assert_called_once() call_args = mock_model_instance.invoke_text_embedding.call_args - assert call_args.kwargs["user"] == user_id assert call_args.kwargs["input_type"] == EmbeddingInputType.DOCUMENT diff --git a/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py b/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py index 
2c234edd9a..cc2873dd3f 100644 --- a/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py +++ b/api/tests/unit_tests/core/rag/indexing/processor/test_paragraph_index_processor.py @@ -2,14 +2,14 @@ from types import SimpleNamespace from unittest.mock import Mock, patch import pytest +from graphon.model_runtime.entities.llm_entities import LLMResult, LLMUsage +from graphon.model_runtime.entities.message_entities import AssistantPromptMessage, ImagePromptMessageContent +from graphon.model_runtime.entities.model_entities import ModelFeature from core.entities.knowledge_entities import PreviewDetail from core.rag.index_processor.constant.index_type import IndexTechniqueType from core.rag.index_processor.processor.paragraph_index_processor import ParagraphIndexProcessor from core.rag.models.document import AttachmentDocument, Document -from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMUsage -from dify_graph.model_runtime.entities.message_entities import AssistantPromptMessage, ImagePromptMessageContent -from dify_graph.model_runtime.entities.model_entities import ModelFeature class TestParagraphIndexProcessor: @@ -400,7 +400,9 @@ class TestParagraphIndexProcessor: model_instance.invoke_llm.return_value = self._llm_result("text summary") with ( - patch("core.rag.index_processor.processor.paragraph_index_processor.ProviderManager") as mock_pm_cls, + patch( + "core.rag.index_processor.processor.paragraph_index_processor.create_plugin_provider_manager" + ) as mock_provider_manager, patch( "core.rag.index_processor.processor.paragraph_index_processor.ModelInstance", return_value=model_instance, @@ -411,7 +413,7 @@ class TestParagraphIndexProcessor: ), patch("core.rag.index_processor.processor.paragraph_index_processor.logger") as mock_logger, ): - mock_pm_cls.return_value.get_provider_model_bundle.return_value = Mock() + mock_provider_manager.return_value.get_provider_model_bundle.return_value = Mock() summary, usage = 
ParagraphIndexProcessor.generate_summary( "tenant-1", "text content", @@ -434,7 +436,9 @@ class TestParagraphIndexProcessor: image_content = ImagePromptMessageContent(format="url", mime_type="image/png", url="http://example.com/a.png") with ( - patch("core.rag.index_processor.processor.paragraph_index_processor.ProviderManager") as mock_pm_cls, + patch( + "core.rag.index_processor.processor.paragraph_index_processor.create_plugin_provider_manager" + ) as mock_provider_manager, patch( "core.rag.index_processor.processor.paragraph_index_processor.ModelInstance", return_value=model_instance, @@ -449,7 +453,7 @@ class TestParagraphIndexProcessor: ), patch("core.rag.index_processor.processor.paragraph_index_processor.deduct_llm_quota"), ): - mock_pm_cls.return_value.get_provider_model_bundle.return_value = Mock() + mock_provider_manager.return_value.get_provider_model_bundle.return_value = Mock() summary, _ = ParagraphIndexProcessor.generate_summary( "tenant-1", "text content", @@ -470,7 +474,9 @@ class TestParagraphIndexProcessor: image_file = SimpleNamespace() with ( - patch("core.rag.index_processor.processor.paragraph_index_processor.ProviderManager") as mock_pm_cls, + patch( + "core.rag.index_processor.processor.paragraph_index_processor.create_plugin_provider_manager" + ) as mock_provider_manager, patch( "core.rag.index_processor.processor.paragraph_index_processor.ModelInstance", return_value=model_instance, @@ -487,7 +493,7 @@ class TestParagraphIndexProcessor: ), patch("core.rag.index_processor.processor.paragraph_index_processor.logger") as mock_logger, ): - mock_pm_cls.return_value.get_provider_model_bundle.return_value = Mock() + mock_provider_manager.return_value.get_provider_model_bundle.return_value = Mock() with pytest.raises(ValueError, match="Expected LLMResult"): ParagraphIndexProcessor.generate_summary( "tenant-1", diff --git a/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py 
b/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py index b54a74b69c..450e716636 100644 --- a/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py +++ b/api/tests/unit_tests/core/rag/indexing/test_indexing_runner.py @@ -53,6 +53,7 @@ from typing import Any from unittest.mock import MagicMock, Mock, patch import pytest +from graphon.model_runtime.entities.model_entities import ModelType from sqlalchemy.orm.exc import ObjectDeletedError from core.errors.error import ProviderTokenNotInitError @@ -63,7 +64,6 @@ from core.indexing_runner import ( ) from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType from core.rag.models.document import ChildDocument, Document -from dify_graph.model_runtime.entities.model_entities import ModelType from libs.datetime_utils import naive_utc_now from models.dataset import Dataset, DatasetProcessRule from models.dataset import Document as DatasetDocument @@ -445,7 +445,7 @@ class TestIndexingRunnerTransform: """Mock all external dependencies for transform tests.""" with ( patch("core.indexing_runner.db") as mock_db, - patch("core.indexing_runner.ModelManager") as mock_model_manager, + patch("core.indexing_runner.ModelManager.for_tenant") as mock_model_manager, ): yield { "db": mock_db, @@ -482,7 +482,8 @@ class TestIndexingRunnerTransform: # Arrange runner = IndexingRunner() mock_embedding_instance = MagicMock() - runner.model_manager.get_model_instance.return_value = mock_embedding_instance + model_manager = mock_dependencies["model_manager"].return_value + model_manager.get_model_instance.return_value = mock_embedding_instance mock_processor = MagicMock() transformed_docs = [ @@ -509,7 +510,7 @@ class TestIndexingRunnerTransform: assert len(result) == 2 assert result[0].page_content == "Chunk 1" assert result[1].page_content == "Chunk 2" - runner.model_manager.get_model_instance.assert_called_once_with( + model_manager.get_model_instance.assert_called_once_with( 
tenant_id=sample_dataset.tenant_id, provider=sample_dataset.embedding_model_provider, model_type=ModelType.TEXT_EMBEDDING, @@ -521,6 +522,7 @@ class TestIndexingRunnerTransform: """Test transformation with economy indexing (no embeddings).""" # Arrange runner = IndexingRunner() + model_manager = mock_dependencies["model_manager"].return_value sample_dataset.indexing_technique = IndexTechniqueType.ECONOMY mock_processor = MagicMock() @@ -539,14 +541,15 @@ class TestIndexingRunnerTransform: # Assert assert len(result) == 1 - runner.model_manager.get_model_instance.assert_not_called() + model_manager.get_model_instance.assert_not_called() def test_transform_with_custom_segmentation(self, mock_dependencies, sample_dataset, sample_text_docs): """Test transformation with custom segmentation rules.""" # Arrange runner = IndexingRunner() mock_embedding_instance = MagicMock() - runner.model_manager.get_model_instance.return_value = mock_embedding_instance + model_manager = mock_dependencies["model_manager"].return_value + model_manager.get_model_instance.return_value = mock_embedding_instance mock_processor = MagicMock() transformed_docs = [Document(page_content="Custom chunk", metadata={"doc_id": "custom1", "doc_hash": "hash1"})] @@ -586,7 +589,7 @@ class TestIndexingRunnerLoad: """Mock all external dependencies for load tests.""" with ( patch("core.indexing_runner.db") as mock_db, - patch("core.indexing_runner.ModelManager") as mock_model_manager, + patch("core.indexing_runner.ModelManager.for_tenant") as mock_model_manager, patch("core.indexing_runner.current_app") as mock_app, patch("core.indexing_runner.threading.Thread") as mock_thread, patch("core.indexing_runner.concurrent.futures.ThreadPoolExecutor") as mock_executor, @@ -645,7 +648,8 @@ class TestIndexingRunnerLoad: runner = IndexingRunner() mock_embedding_instance = MagicMock() mock_embedding_instance.get_text_embedding_num_tokens.return_value = 100 - runner.model_manager.get_model_instance.return_value = 
mock_embedding_instance + model_manager = mock_dependencies["model_manager"].return_value + model_manager.get_model_instance.return_value = mock_embedding_instance mock_processor = MagicMock() @@ -664,7 +668,7 @@ class TestIndexingRunnerLoad: runner._load(mock_processor, sample_dataset, sample_dataset_document, sample_documents) # Assert - runner.model_manager.get_model_instance.assert_called_once() + model_manager.get_model_instance.assert_called_once() # Verify executor was used for parallel processing assert mock_executor_instance.submit.called @@ -714,7 +718,8 @@ class TestIndexingRunnerLoad: mock_embedding_instance = MagicMock() mock_embedding_instance.get_text_embedding_num_tokens.return_value = 50 - runner.model_manager.get_model_instance.return_value = mock_embedding_instance + model_manager = mock_dependencies["model_manager"].return_value + model_manager.get_model_instance.return_value = mock_embedding_instance mock_processor = MagicMock() @@ -754,7 +759,7 @@ class TestIndexingRunnerRun: with ( patch("core.indexing_runner.db") as mock_db, patch("core.indexing_runner.IndexProcessorFactory") as mock_factory, - patch("core.indexing_runner.ModelManager") as mock_model_manager, + patch("core.indexing_runner.ModelManager.for_tenant") as mock_model_manager, patch("core.indexing_runner.storage") as mock_storage, patch("core.indexing_runner.threading.Thread") as mock_thread, ): diff --git a/api/tests/unit_tests/core/rag/rerank/test_reranker.py b/api/tests/unit_tests/core/rag/rerank/test_reranker.py index b150d677f1..2ec7f0498e 100644 --- a/api/tests/unit_tests/core/rag/rerank/test_reranker.py +++ b/api/tests/unit_tests/core/rag/rerank/test_reranker.py @@ -17,6 +17,7 @@ from types import SimpleNamespace from unittest.mock import MagicMock, Mock, patch import pytest +from graphon.model_runtime.entities.rerank_entities import RerankDocument, RerankResult from core.model_manager import ModelInstance from core.rag.index_processor.constant.doc_type import DocType @@ 
-28,7 +29,6 @@ from core.rag.rerank.rerank_factory import RerankRunnerFactory from core.rag.rerank.rerank_model import RerankModelRunner from core.rag.rerank.rerank_type import RerankMode from core.rag.rerank.weight_rerank import WeightRerankRunner -from dify_graph.model_runtime.entities.rerank_entities import RerankDocument, RerankResult def create_mock_model_instance() -> ModelInstance: @@ -57,7 +57,7 @@ class TestRerankModelRunner: @pytest.fixture(autouse=True) def mock_model_manager(self): """Auto-use fixture to patch ModelManager for all tests in this class.""" - with patch("core.rag.rerank.rerank_model.ModelManager", autospec=True) as mock_mm: + with patch("core.rag.rerank.rerank_model.ModelManager.for_tenant", autospec=True) as mock_mm: mock_mm.return_value.check_model_support_vision.return_value = False yield mock_mm @@ -352,12 +352,14 @@ class TestRerankModelRunner: # Assert: Empty result is returned assert len(result) == 0 - def test_user_parameter_passed_to_model(self, rerank_runner, mock_model_instance, sample_documents): - """Test that user parameter is passed to model invocation. + def test_run_uses_bound_model_instance( + self, rerank_runner, mock_model_instance, sample_documents, mock_model_manager + ): + """Test that rerank uses the bound model instance directly. Verifies: - - User ID is correctly forwarded to the model - - Model receives all expected parameters + - The injected model instance is used for invocation + - No late rebinding occurs through ModelManager.get_model_instance """ # Arrange: Mock rerank result mock_rerank_result = RerankResult( @@ -368,16 +370,18 @@ class TestRerankModelRunner: ) mock_model_instance.invoke_rerank.return_value = mock_rerank_result - # Act: Run reranking with user parameter + # Act: Run reranking result = rerank_runner.run( query="test", documents=sample_documents, - user="user123", ) - # Assert: User parameter is passed to model + # Assert: The injected model instance is invoked directly. 
+ assert len(result) == 1 + mock_model_manager.return_value.get_model_instance.assert_not_called() call_kwargs = mock_model_instance.invoke_rerank.call_args.kwargs - assert call_kwargs["user"] == "user123" + assert call_kwargs["query"] == "test" + assert "user" not in call_kwargs class _ForwardingBaseRerankRunner(BaseRerankRunner): @@ -387,7 +391,6 @@ class _ForwardingBaseRerankRunner(BaseRerankRunner): documents: list[Document], score_threshold: float | None = None, top_n: int | None = None, - user: str | None = None, query_type: QueryType = QueryType.TEXT_QUERY, ) -> list[Document]: return super().run( @@ -395,7 +398,6 @@ class _ForwardingBaseRerankRunner(BaseRerankRunner): documents=documents, score_threshold=score_threshold, top_n=top_n, - user=user, query_type=query_type, ) @@ -424,7 +426,7 @@ class TestRerankModelRunnerMultimodal: Document(page_content="doc", metadata={"doc_id": "doc1"}, provider="dify"), ] - with patch("core.rag.rerank.rerank_model.ModelManager") as mock_mm: + with patch("core.rag.rerank.rerank_model.ModelManager.for_tenant") as mock_mm: mock_mm.return_value.check_model_support_vision.return_value = False result = rerank_runner.run(query="image-file-id", documents=documents, query_type=QueryType.IMAGE_QUERY) @@ -441,7 +443,7 @@ class TestRerankModelRunnerMultimodal: ) with ( - patch("core.rag.rerank.rerank_model.ModelManager") as mock_mm, + patch("core.rag.rerank.rerank_model.ModelManager.for_tenant") as mock_mm, patch.object( rerank_runner, "fetch_multimodal_rerank", @@ -539,8 +541,10 @@ class TestRerankModelRunnerMultimodal: ) mock_model_instance.invoke_multimodal_rerank.return_value = rerank_result + session = MagicMock() + session.query.return_value = query_chain with ( - patch("core.rag.rerank.rerank_model.db.session.query", return_value=query_chain), + patch("core.rag.rerank.rerank_model.db.session", session), patch("core.rag.rerank.rerank_model.storage.load_once", return_value=b"query-image-bytes"), ): result, unique_documents = 
rerank_runner.fetch_multimodal_rerank( @@ -548,7 +552,6 @@ class TestRerankModelRunnerMultimodal: documents=[text_doc], score_threshold=0.2, top_n=2, - user="user-1", query_type=QueryType.IMAGE_QUERY, ) @@ -557,7 +560,7 @@ class TestRerankModelRunnerMultimodal: invoke_kwargs = mock_model_instance.invoke_multimodal_rerank.call_args.kwargs assert invoke_kwargs["query"]["content_type"] == DocType.IMAGE assert invoke_kwargs["docs"][0]["content"] == "text-content" - assert invoke_kwargs["user"] == "user-1" + assert "user" not in invoke_kwargs def test_fetch_multimodal_rerank_raises_when_query_image_not_found(self, rerank_runner): query_chain = Mock() @@ -595,7 +598,7 @@ class TestWeightRerankRunner: @pytest.fixture def mock_model_manager(self): """Mock ModelManager for embedding model.""" - with patch("core.rag.rerank.weight_rerank.ModelManager", autospec=True) as mock_manager: + with patch("core.rag.rerank.weight_rerank.ModelManager.for_tenant") as mock_manager: yield mock_manager @pytest.fixture @@ -1145,7 +1148,7 @@ class TestRerankIntegration: @pytest.fixture(autouse=True) def mock_model_manager(self): """Auto-use fixture to patch ModelManager for all tests in this class.""" - with patch("core.rag.rerank.rerank_model.ModelManager", autospec=True) as mock_mm: + with patch("core.rag.rerank.rerank_model.ModelManager.for_tenant", autospec=True) as mock_mm: mock_mm.return_value.check_model_support_vision.return_value = False yield mock_mm @@ -1257,7 +1260,7 @@ class TestRerankEdgeCases: @pytest.fixture(autouse=True) def mock_model_manager(self): """Auto-use fixture to patch ModelManager for all tests in this class.""" - with patch("core.rag.rerank.rerank_model.ModelManager", autospec=True) as mock_mm: + with patch("core.rag.rerank.rerank_model.ModelManager.for_tenant", autospec=True) as mock_mm: mock_mm.return_value.check_model_support_vision.return_value = False yield mock_mm @@ -1527,7 +1530,7 @@ class TestRerankEdgeCases: # Mock dependencies with ( 
patch("core.rag.rerank.weight_rerank.JiebaKeywordTableHandler", autospec=True) as mock_jieba, - patch("core.rag.rerank.weight_rerank.ModelManager", autospec=True) as mock_manager, + patch("core.rag.rerank.weight_rerank.ModelManager.for_tenant") as mock_manager, patch("core.rag.rerank.weight_rerank.CacheEmbedding", autospec=True) as mock_cache, ): mock_handler = MagicMock() @@ -1598,7 +1601,7 @@ class TestRerankPerformance: @pytest.fixture(autouse=True) def mock_model_manager(self): """Auto-use fixture to patch ModelManager for all tests in this class.""" - with patch("core.rag.rerank.rerank_model.ModelManager", autospec=True) as mock_mm: + with patch("core.rag.rerank.rerank_model.ModelManager.for_tenant", autospec=True) as mock_mm: mock_mm.return_value.check_model_support_vision.return_value = False yield mock_mm @@ -1673,7 +1676,7 @@ class TestRerankPerformance: with ( patch("core.rag.rerank.weight_rerank.JiebaKeywordTableHandler", autospec=True) as mock_jieba, - patch("core.rag.rerank.weight_rerank.ModelManager", autospec=True) as mock_manager, + patch("core.rag.rerank.weight_rerank.ModelManager.for_tenant") as mock_manager, patch("core.rag.rerank.weight_rerank.CacheEmbedding", autospec=True) as mock_cache, ): mock_handler = MagicMock() @@ -1715,7 +1718,7 @@ class TestRerankErrorHandling: @pytest.fixture(autouse=True) def mock_model_manager(self): """Auto-use fixture to patch ModelManager for all tests in this class.""" - with patch("core.rag.rerank.rerank_model.ModelManager", autospec=True) as mock_mm: + with patch("core.rag.rerank.rerank_model.ModelManager.for_tenant", autospec=True) as mock_mm: mock_mm.return_value.check_model_support_vision.return_value = False yield mock_mm @@ -1824,7 +1827,7 @@ class TestRerankErrorHandling: with ( patch("core.rag.rerank.weight_rerank.JiebaKeywordTableHandler", autospec=True) as mock_jieba, - patch("core.rag.rerank.weight_rerank.ModelManager", autospec=True) as mock_manager, + 
patch("core.rag.rerank.weight_rerank.ModelManager.for_tenant") as mock_manager, patch("core.rag.rerank.weight_rerank.CacheEmbedding", autospec=True) as mock_cache, ): mock_handler = MagicMock() diff --git a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py index a34ca330ca..c11426163e 100644 --- a/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py +++ b/api/tests/unit_tests/core/rag/retrieval/test_dataset_retrieval.py @@ -6,6 +6,8 @@ from uuid import uuid4 import pytest from flask import Flask, current_app +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.model_runtime.entities.model_entities import ModelFeature from sqlalchemy import column from core.app.app_config.entities import ( @@ -35,9 +37,8 @@ from core.rag.retrieval.dataset_retrieval import DatasetRetrieval from core.rag.retrieval.retrieval_methods import RetrievalMethod from core.workflow.nodes.knowledge_retrieval import exc from core.workflow.nodes.knowledge_retrieval.retrieval import KnowledgeRetrievalRequest -from dify_graph.model_runtime.entities.llm_entities import LLMUsage -from dify_graph.model_runtime.entities.model_entities import ModelFeature from models.dataset import Dataset +from models.enums import CreatorUserRole # ==================== Helper Functions ==================== @@ -3747,6 +3748,24 @@ class TestDatasetRetrievalAdditionalHelpers: mock_session.add_all.assert_called() mock_session.commit.assert_called() + def test_on_query_normalizes_workflow_end_user_role(self, retrieval: DatasetRetrieval) -> None: + with patch("core.rag.retrieval.dataset_retrieval.db.session") as mock_session: + retrieval._on_query( + query="python", + attachment_ids=None, + dataset_ids=["d1"], + app_id="a1", + user_from="end-user", + user_id="u1", + ) + + mock_session.add_all.assert_called_once() + added_queries = mock_session.add_all.call_args.args[0] + + assert len(added_queries) == 1 
+ assert added_queries[0].created_by_role == CreatorUserRole.END_USER + mock_session.commit.assert_called_once() + def test_handle_invoke_result(self, retrieval: DatasetRetrieval) -> None: usage = LLMUsage.empty_usage() chunk_1 = SimpleNamespace( @@ -3836,7 +3855,7 @@ class TestDatasetRetrievalAdditionalHelpers: model_instance.model_type_instance.get_model_schema.return_value = Mock() with ( - patch("core.rag.retrieval.dataset_retrieval.ModelManager") as mock_manager, + patch("core.rag.retrieval.dataset_retrieval.ModelManager.for_tenant") as mock_manager, patch("core.rag.retrieval.dataset_retrieval.ModelConfigWithCredentialsEntity") as mock_cfg_entity, ): mock_manager.return_value.get_model_instance.return_value = model_instance @@ -4222,11 +4241,12 @@ class TestKnowledgeRetrievalCoverage: with ( patch.object(retrieval, "_check_knowledge_rate_limit"), patch.object(retrieval, "_get_available_datasets", return_value=[SimpleNamespace(id="dataset-1")]), - patch("core.rag.retrieval.dataset_retrieval.ModelManager") as mock_model_manager, + patch("core.rag.retrieval.dataset_retrieval.ModelManager.for_tenant") as mock_model_manager, ): mock_model_manager.return_value.get_model_instance.return_value = model_instance with pytest.raises(Exception) as exc_info: retrieval.knowledge_retrieval(request) + mock_model_manager.assert_called_once_with(tenant_id="tenant-1", user_id="user-1") assert error_cls in type(exc_info.value).__name__ @@ -4279,9 +4299,13 @@ class TestRetrieveCoverage: ), ) model_config = self._build_model_config() - model_config.provider_model_bundle.model_type_instance.get_model_schema.return_value = None - with patch("core.rag.retrieval.dataset_retrieval.ModelManager") as mock_model_manager: - mock_model_manager.return_value.get_model_instance.return_value = Mock() + model_instance = Mock() + model_instance.model_name = "gpt-4" + model_instance.credentials = {"api_key": "secret"} + model_instance.provider_model_bundle = Mock() + 
model_instance.model_type_instance.get_model_schema.return_value = None + with patch("core.rag.retrieval.dataset_retrieval.ModelManager.for_tenant") as mock_model_manager: + mock_model_manager.return_value.get_model_instance.return_value = model_instance result = retrieval.retrieve( app_id="app-1", user_id="user-1", @@ -4294,8 +4318,58 @@ class TestRetrieveCoverage: hit_callback=Mock(), message_id="m1", ) + mock_model_manager.assert_called_once_with(tenant_id="tenant-1", user_id="user-1") assert result == (None, []) + def test_retrieve_uses_bound_model_instance_schema_and_updates_model_config( + self, retrieval: DatasetRetrieval + ) -> None: + config = DatasetEntity( + dataset_ids=["d1"], + retrieve_config=DatasetRetrieveConfigEntity( + retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE, + metadata_filtering_mode="disabled", + ), + ) + model_config = self._build_model_config(features=[]) + model_config.provider_model_bundle.model_type_instance.get_model_schema.return_value = None + bound_schema = SimpleNamespace(features=[ModelFeature.TOOL_CALL]) + bound_bundle = Mock() + bound_model_instance = Mock() + bound_model_instance.model_name = "gpt-4" + bound_model_instance.credentials = {"api_key": "secret"} + bound_model_instance.provider_model_bundle = bound_bundle + bound_model_instance.model_type_instance.get_model_schema.return_value = bound_schema + + with ( + patch("core.rag.retrieval.dataset_retrieval.ModelManager.for_tenant") as mock_model_manager, + patch.object(retrieval, "_get_available_datasets", return_value=[SimpleNamespace(id="d1")]), + patch.object(retrieval, "get_metadata_filter_condition", return_value=(None, None)), + patch.object(retrieval, "single_retrieve", return_value=[]) as mock_single_retrieve, + ): + mock_model_manager.return_value.get_model_instance.return_value = bound_model_instance + context, files = retrieval.retrieve( + app_id="app-1", + user_id="user-1", + tenant_id="tenant-1", + model_config=model_config, + 
config=config, + query="python", + invoke_from=InvokeFrom.WEB_APP, + show_retrieve_source=False, + hit_callback=Mock(), + message_id="m1", + ) + + mock_model_manager.assert_called_once_with(tenant_id="tenant-1", user_id="user-1") + mock_single_retrieve.assert_called_once() + assert mock_single_retrieve.call_args.args[8] == PlanningStrategy.ROUTER + assert model_config.provider_model_bundle is bound_bundle + assert model_config.credentials == {"api_key": "secret"} + assert model_config.model_schema is bound_schema + assert context == "" + assert files == [] + def test_single_strategy_with_external_documents(self, retrieval: DatasetRetrieval) -> None: retrieve_config = DatasetRetrieveConfigEntity( retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE, @@ -4312,12 +4386,17 @@ class TestRetrieveCoverage: extra={"title": "External", "dataset_name": "External DS"}, ) with ( - patch("core.rag.retrieval.dataset_retrieval.ModelManager") as mock_model_manager, + patch("core.rag.retrieval.dataset_retrieval.ModelManager.for_tenant") as mock_model_manager, patch.object(retrieval, "_get_available_datasets", return_value=[SimpleNamespace(id="d1")]), patch.object(retrieval, "get_metadata_filter_condition", return_value=(None, None)), patch.object(retrieval, "single_retrieve", return_value=[external_doc]), ): - mock_model_manager.return_value.get_model_instance.return_value = Mock() + bound_model_instance = Mock() + bound_model_instance.model_name = "gpt-4" + bound_model_instance.credentials = {} + bound_model_instance.provider_model_bundle = Mock() + bound_model_instance.model_type_instance.get_model_schema.return_value = SimpleNamespace(features=[]) + mock_model_manager.return_value.get_model_instance.return_value = bound_model_instance context, files = retrieval.retrieve( app_id="app-1", user_id="user-1", @@ -4402,7 +4481,7 @@ class TestRetrieveCoverage: hit_callback = Mock() with ( - patch("core.rag.retrieval.dataset_retrieval.ModelManager") as 
mock_model_manager, + patch("core.rag.retrieval.dataset_retrieval.ModelManager.for_tenant") as mock_model_manager, patch.object(retrieval, "_get_available_datasets", return_value=[SimpleNamespace(id="d1")]), patch.object(retrieval, "get_metadata_filter_condition", return_value=(None, None)), patch.object(retrieval, "multiple_retrieve", return_value=[external_doc, dify_doc]), @@ -4413,7 +4492,14 @@ class TestRetrieveCoverage: patch("core.rag.retrieval.dataset_retrieval.sign_upload_file", return_value="https://signed"), patch("core.rag.retrieval.dataset_retrieval.db.session.execute") as mock_execute, ): - mock_model_manager.return_value.get_model_instance.return_value = Mock() + bound_model_instance = Mock() + bound_model_instance.model_name = "gpt-4" + bound_model_instance.credentials = {} + bound_model_instance.provider_model_bundle = Mock() + bound_model_instance.model_type_instance.get_model_schema.return_value = SimpleNamespace( + features=[ModelFeature.TOOL_CALL] + ) + mock_model_manager.return_value.get_model_instance.return_value = bound_model_instance mock_execute.side_effect = [execute_attachments, execute_docs, execute_datasets] context, files = retrieval.retrieve( app_id="app-1", diff --git a/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_function_call_router.py b/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_function_call_router.py index cfa9094e12..5a2ecb8220 100644 --- a/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_function_call_router.py +++ b/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_function_call_router.py @@ -1,7 +1,8 @@ from unittest.mock import Mock +from graphon.model_runtime.entities.llm_entities import LLMUsage + from core.rag.retrieval.router.multi_dataset_function_call_router import FunctionCallMultiDatasetRouter -from dify_graph.model_runtime.entities.llm_entities import LLMUsage class TestFunctionCallMultiDatasetRouter: diff --git 
a/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_react_route.py b/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_react_route.py index e429563739..539ac0f849 100644 --- a/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_react_route.py +++ b/api/tests/unit_tests/core/rag/retrieval/test_multi_dataset_react_route.py @@ -1,10 +1,12 @@ from types import SimpleNamespace from unittest.mock import Mock, patch +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.model_runtime.entities.message_entities import PromptMessageRole +from graphon.model_runtime.entities.model_entities import ModelType + from core.rag.retrieval.output_parser.react_output import ReactAction, ReactFinish from core.rag.retrieval.router.multi_dataset_react_route import ReactMultiDatasetRouter -from dify_graph.model_runtime.entities.llm_entities import LLMUsage -from dify_graph.model_runtime.entities.message_entities import PromptMessageRole class TestReactMultiDatasetRouter: @@ -87,6 +89,7 @@ class TestReactMultiDatasetRouter: model_config = Mock() model_config.mode = "chat" model_config.parameters = {"temperature": 0.1} + model_instance = Mock() usage = LLMUsage.empty_usage() tools = [Mock(name="dataset-1"), Mock(name="dataset-2")] tools[0].name = "dataset-1" @@ -108,13 +111,14 @@ class TestReactMultiDatasetRouter: dataset_id, returned_usage = router._react_invoke( query="python", model_config=model_config, - model_instance=Mock(), + model_instance=model_instance, tools=tools, user_id="u1", tenant_id="t1", ) mock_chat_prompt.assert_called_once() + assert mock_prompt_transform.return_value.get_prompt.call_args.kwargs["model_instance"] is model_instance assert dataset_id == "dataset-2" assert returned_usage == usage @@ -162,7 +166,11 @@ class TestReactMultiDatasetRouter: model_instance = Mock() model_instance.invoke_llm.return_value = iter([chunk]) - with patch("core.rag.retrieval.router.multi_dataset_react_route.deduct_llm_quota") as 
mock_deduct: + with ( + patch("core.rag.retrieval.router.multi_dataset_react_route.ModelManager.for_tenant") as mock_manager, + patch("core.rag.retrieval.router.multi_dataset_react_route.deduct_llm_quota") as mock_deduct, + ): + mock_manager.return_value.get_model_instance.return_value = model_instance text, returned_usage = router._invoke_llm( completion_param={"temperature": 0.1}, model_instance=model_instance, @@ -174,6 +182,13 @@ class TestReactMultiDatasetRouter: assert text == "part" assert returned_usage == usage + mock_manager.assert_called_once_with(tenant_id="t1", user_id="u1") + mock_manager.return_value.get_model_instance.assert_called_once_with( + tenant_id="t1", + provider=model_instance.provider, + model_type=ModelType.LLM, + model=model_instance.model_name, + ) mock_deduct.assert_called_once() def test_handle_invoke_result_with_empty_usage(self) -> None: diff --git a/api/tests/unit_tests/core/repositories/test_celery_workflow_execution_repository.py b/api/tests/unit_tests/core/repositories/test_celery_workflow_execution_repository.py index e7eecfa297..e229d5fc1a 100644 --- a/api/tests/unit_tests/core/repositories/test_celery_workflow_execution_repository.py +++ b/api/tests/unit_tests/core/repositories/test_celery_workflow_execution_repository.py @@ -9,9 +9,10 @@ from unittest.mock import Mock, patch from uuid import uuid4 import pytest +from graphon.entities import WorkflowExecution +from graphon.enums import WorkflowType from core.repositories.celery_workflow_execution_repository import CeleryWorkflowExecutionRepository -from dify_graph.entities.workflow_execution import WorkflowExecution, WorkflowType from libs.datetime_utils import naive_utc_now from models import Account, EndUser from models.enums import WorkflowRunTriggeredFrom diff --git a/api/tests/unit_tests/core/repositories/test_celery_workflow_node_execution_repository.py b/api/tests/unit_tests/core/repositories/test_celery_workflow_node_execution_repository.py index 
2a83a4e802..7dbf78d0f0 100644 --- a/api/tests/unit_tests/core/repositories/test_celery_workflow_node_execution_repository.py +++ b/api/tests/unit_tests/core/repositories/test_celery_workflow_node_execution_repository.py @@ -9,14 +9,14 @@ from unittest.mock import Mock, patch from uuid import uuid4 import pytest - -from core.repositories.celery_workflow_node_execution_repository import CeleryWorkflowNodeExecutionRepository -from dify_graph.entities.workflow_node_execution import ( +from graphon.entities.workflow_node_execution import ( WorkflowNodeExecution, WorkflowNodeExecutionStatus, ) -from dify_graph.enums import BuiltinNodeTypes -from dify_graph.repositories.workflow_node_execution_repository import OrderConfig +from graphon.enums import BuiltinNodeTypes + +from core.repositories.celery_workflow_node_execution_repository import CeleryWorkflowNodeExecutionRepository +from core.repositories.factory import OrderConfig from libs.datetime_utils import naive_utc_now from models import Account, EndUser from models.workflow import WorkflowNodeExecutionTriggeredFrom @@ -181,10 +181,10 @@ class TestCeleryWorkflowNodeExecutionRepository: repo.save(sample_workflow_node_execution) @patch("core.repositories.celery_workflow_node_execution_repository.save_workflow_node_execution_task") - def test_get_by_workflow_run_from_cache( + def test_get_by_workflow_execution_from_cache( self, mock_task, mock_session_factory, mock_account, sample_workflow_node_execution ): - """Test that get_by_workflow_run retrieves executions from cache.""" + """Test that get_by_workflow_execution retrieves executions from cache.""" repo = CeleryWorkflowNodeExecutionRepository( session_factory=mock_session_factory, user=mock_account, @@ -195,18 +195,18 @@ class TestCeleryWorkflowNodeExecutionRepository: # Save execution to cache first repo.save(sample_workflow_node_execution) - workflow_run_id = sample_workflow_node_execution.workflow_execution_id + workflow_execution_id = 
sample_workflow_node_execution.workflow_execution_id order_config = OrderConfig(order_by=["index"], order_direction="asc") - result = repo.get_by_workflow_run(workflow_run_id, order_config) + result = repo.get_by_workflow_execution(workflow_execution_id, order_config) # Verify results were retrieved from cache assert len(result) == 1 assert result[0].id == sample_workflow_node_execution.id assert result[0] is sample_workflow_node_execution - def test_get_by_workflow_run_without_order_config(self, mock_session_factory, mock_account): - """Test get_by_workflow_run without order configuration.""" + def test_get_by_workflow_execution_without_order_config(self, mock_session_factory, mock_account): + """Test get_by_workflow_execution without order configuration.""" repo = CeleryWorkflowNodeExecutionRepository( session_factory=mock_session_factory, user=mock_account, @@ -214,7 +214,7 @@ class TestCeleryWorkflowNodeExecutionRepository: triggered_from=WorkflowNodeExecutionTriggeredFrom.WORKFLOW_RUN, ) - result = repo.get_by_workflow_run("workflow-run-id") + result = repo.get_by_workflow_execution("workflow-run-id") # Should return empty list since nothing in cache assert len(result) == 0 @@ -236,7 +236,7 @@ class TestCeleryWorkflowNodeExecutionRepository: assert sample_workflow_node_execution.id in repo._execution_cache # Test retrieving from cache - result = repo.get_by_workflow_run(sample_workflow_node_execution.workflow_execution_id) + result = repo.get_by_workflow_execution(sample_workflow_node_execution.workflow_execution_id) assert len(result) == 1 assert result[0].id == sample_workflow_node_execution.id @@ -251,12 +251,12 @@ class TestCeleryWorkflowNodeExecutionRepository: ) # Create multiple executions for the same workflow - workflow_run_id = str(uuid4()) + workflow_execution_id = str(uuid4()) exec1 = WorkflowNodeExecution( id=str(uuid4()), node_execution_id=str(uuid4()), workflow_id=str(uuid4()), - workflow_execution_id=workflow_run_id, + 
workflow_execution_id=workflow_execution_id, index=1, node_id="node1", node_type=BuiltinNodeTypes.START, @@ -269,7 +269,7 @@ class TestCeleryWorkflowNodeExecutionRepository: id=str(uuid4()), node_execution_id=str(uuid4()), workflow_id=str(uuid4()), - workflow_execution_id=workflow_run_id, + workflow_execution_id=workflow_execution_id, index=2, node_id="node2", node_type=BuiltinNodeTypes.LLM, @@ -285,10 +285,10 @@ class TestCeleryWorkflowNodeExecutionRepository: # Verify both are cached and mapped assert len(repo._execution_cache) == 2 - assert len(repo._workflow_execution_mapping[workflow_run_id]) == 2 + assert len(repo._workflow_execution_mapping[workflow_execution_id]) == 2 # Test retrieval - result = repo.get_by_workflow_run(workflow_run_id) + result = repo.get_by_workflow_execution(workflow_execution_id) assert len(result) == 2 @patch("core.repositories.celery_workflow_node_execution_repository.save_workflow_node_execution_task") @@ -302,12 +302,12 @@ class TestCeleryWorkflowNodeExecutionRepository: ) # Create executions with different indices - workflow_run_id = str(uuid4()) + workflow_execution_id = str(uuid4()) exec1 = WorkflowNodeExecution( id=str(uuid4()), node_execution_id=str(uuid4()), workflow_id=str(uuid4()), - workflow_execution_id=workflow_run_id, + workflow_execution_id=workflow_execution_id, index=2, node_id="node2", node_type=BuiltinNodeTypes.START, @@ -320,7 +320,7 @@ class TestCeleryWorkflowNodeExecutionRepository: id=str(uuid4()), node_execution_id=str(uuid4()), workflow_id=str(uuid4()), - workflow_execution_id=workflow_run_id, + workflow_execution_id=workflow_execution_id, index=1, node_id="node1", node_type=BuiltinNodeTypes.LLM, @@ -336,14 +336,14 @@ class TestCeleryWorkflowNodeExecutionRepository: # Test ascending order order_config = OrderConfig(order_by=["index"], order_direction="asc") - result = repo.get_by_workflow_run(workflow_run_id, order_config) + result = repo.get_by_workflow_execution(workflow_execution_id, order_config) assert 
len(result) == 2 assert result[0].index == 1 assert result[1].index == 2 # Test descending order order_config = OrderConfig(order_by=["index"], order_direction="desc") - result = repo.get_by_workflow_run(workflow_run_id, order_config) + result = repo.get_by_workflow_execution(workflow_execution_id, order_config) assert len(result) == 2 assert result[0].index == 2 assert result[1].index == 1 diff --git a/api/tests/unit_tests/core/repositories/test_factory.py b/api/tests/unit_tests/core/repositories/test_factory.py index fe9eed0307..48327c3913 100644 --- a/api/tests/unit_tests/core/repositories/test_factory.py +++ b/api/tests/unit_tests/core/repositories/test_factory.py @@ -11,9 +11,12 @@ import pytest from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker -from core.repositories.factory import DifyCoreRepositoryFactory, RepositoryImportError -from dify_graph.repositories.workflow_execution_repository import WorkflowExecutionRepository -from dify_graph.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository +from core.repositories.factory import ( + DifyCoreRepositoryFactory, + RepositoryImportError, + WorkflowExecutionRepository, + WorkflowNodeExecutionRepository, +) from libs.module_loading import import_string from models import Account, EndUser from models.enums import WorkflowRunTriggeredFrom diff --git a/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py b/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py index 9af4d12664..0fc82dda53 100644 --- a/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py +++ b/api/tests/unit_tests/core/repositories/test_human_input_form_repository_impl.py @@ -7,6 +7,11 @@ from datetime import datetime from types import SimpleNamespace import pytest +from graphon.nodes.human_input.entities import ( + FormDefinition, + UserAction, +) +from graphon.nodes.human_input.enums import HumanInputFormKind, 
HumanInputFormStatus from core.repositories.human_input_repository import ( HumanInputFormRecord, @@ -14,16 +19,13 @@ from core.repositories.human_input_repository import ( HumanInputFormSubmissionRepository, _WorkspaceMemberInfo, ) -from dify_graph.nodes.human_input.entities import ( +from core.workflow.human_input_compat import ( EmailDeliveryConfig, EmailDeliveryMethod, EmailRecipients, ExternalRecipient, - FormDefinition, MemberRecipient, - UserAction, ) -from dify_graph.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from libs.datetime_utils import naive_utc_now from models.human_input import ( EmailExternalRecipientPayload, @@ -89,9 +91,9 @@ class TestHumanInputFormRepositoryImplHelpers: form_id="form-id", delivery_id="delivery-id", recipients_config=EmailRecipients( - whole_workspace=False, + include_bound_group=False, items=[ - MemberRecipient(user_id="member-1"), + MemberRecipient(reference_id="member-1"), ExternalRecipient(email="external@example.com"), ], ), @@ -125,9 +127,9 @@ class TestHumanInputFormRepositoryImplHelpers: form_id="form-id", delivery_id="delivery-id", recipients_config=EmailRecipients( - whole_workspace=False, + include_bound_group=False, items=[ - MemberRecipient(user_id="missing-member"), + MemberRecipient(reference_id="missing-member"), ExternalRecipient(email="external@example.com"), ], ), @@ -156,7 +158,7 @@ class TestHumanInputFormRepositoryImplHelpers: form_id="form-id", delivery_id="delivery-id", recipients_config=EmailRecipients( - whole_workspace=True, + include_bound_group=True, items=[], ), ) @@ -182,7 +184,7 @@ class TestHumanInputFormRepositoryImplHelpers: form_id="form-id", delivery_id="delivery-id", recipients_config=EmailRecipients( - whole_workspace=False, + include_bound_group=False, items=[ ExternalRecipient(email="external@example.com"), ExternalRecipient(email="external@example.com"), @@ -212,9 +214,9 @@ class TestHumanInputFormRepositoryImplHelpers: form_id="form-id", 
delivery_id="delivery-id", recipients_config=EmailRecipients( - whole_workspace=False, + include_bound_group=False, items=[ - MemberRecipient(user_id="member-1"), + MemberRecipient(reference_id="member-1"), ExternalRecipient(email="shared@example.com"), ], ), @@ -243,7 +245,7 @@ class TestHumanInputFormRepositoryImplHelpers: method = EmailDeliveryMethod( config=EmailDeliveryConfig( recipients=EmailRecipients( - whole_workspace=True, + include_bound_group=True, items=[ExternalRecipient(email="external@example.com")], ), subject="subject", @@ -272,7 +274,7 @@ def _make_form_definition() -> str: inputs=[], user_actions=[UserAction(id="submit", title="Submit")], rendered_content="

hello

", - expiration_time=datetime.utcnow(), + expiration_time=naive_utc_now(), ).model_dump_json() @@ -421,22 +423,22 @@ class TestHumanInputFormRepositoryImplPublicMethods: ) session = _FakeSession(scalars_results=[form, [recipient]]) _patch_repo_session_factory(monkeypatch, session) - repo = HumanInputFormRepositoryImpl(tenant_id="tenant-id") + repo = HumanInputFormRepositoryImpl(tenant_id="tenant-id", workflow_execution_id=form.workflow_run_id) - entity = repo.get_form(form.workflow_run_id, form.node_id) + entity = repo.get_form(form.node_id) assert entity is not None assert entity.id == form.id - assert entity.web_app_token == "token-123" + assert entity.submission_token == "token-123" assert len(entity.recipients) == 1 assert entity.recipients[0].token == "token-123" def test_get_form_returns_none_when_missing(self, monkeypatch: pytest.MonkeyPatch): session = _FakeSession(scalars_results=[None]) _patch_repo_session_factory(monkeypatch, session) - repo = HumanInputFormRepositoryImpl(tenant_id="tenant-id") + repo = HumanInputFormRepositoryImpl(tenant_id="tenant-id", workflow_execution_id="run-1") - assert repo.get_form("run-1", "node-1") is None + assert repo.get_form("node-1") is None def test_get_form_returns_unsubmitted_state(self, monkeypatch: pytest.MonkeyPatch): form = _DummyForm( @@ -451,9 +453,9 @@ class TestHumanInputFormRepositoryImplPublicMethods: ) session = _FakeSession(scalars_results=[form, []]) _patch_repo_session_factory(monkeypatch, session) - repo = HumanInputFormRepositoryImpl(tenant_id="tenant-id") + repo = HumanInputFormRepositoryImpl(tenant_id="tenant-id", workflow_execution_id=form.workflow_run_id) - entity = repo.get_form(form.workflow_run_id, form.node_id) + entity = repo.get_form(form.node_id) assert entity is not None assert entity.submitted is False @@ -476,9 +478,9 @@ class TestHumanInputFormRepositoryImplPublicMethods: ) session = _FakeSession(scalars_results=[form, []]) _patch_repo_session_factory(monkeypatch, session) - repo = 
HumanInputFormRepositoryImpl(tenant_id="tenant-id") + repo = HumanInputFormRepositoryImpl(tenant_id="tenant-id", workflow_execution_id=form.workflow_run_id) - entity = repo.get_form(form.workflow_run_id, form.node_id) + entity = repo.get_form(form.node_id) assert entity is not None assert entity.submitted is True diff --git a/api/tests/unit_tests/core/repositories/test_human_input_repository.py b/api/tests/unit_tests/core/repositories/test_human_input_repository.py index 4116e8b4a5..8ff0e40587 100644 --- a/api/tests/unit_tests/core/repositories/test_human_input_repository.py +++ b/api/tests/unit_tests/core/repositories/test_human_input_repository.py @@ -9,8 +9,12 @@ from typing import Any from unittest.mock import MagicMock import pytest +from graphon.nodes.human_input.entities import HumanInputNodeData, UserAction +from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from core.repositories.human_input_repository import ( + FormCreateParams, + FormNotFoundError, HumanInputFormRecord, HumanInputFormRepositoryImpl, HumanInputFormSubmissionRepository, @@ -19,18 +23,14 @@ from core.repositories.human_input_repository import ( _InvalidTimeoutStatusError, _WorkspaceMemberInfo, ) -from dify_graph.nodes.human_input.entities import ( +from core.workflow.human_input_compat import ( EmailDeliveryConfig, EmailDeliveryMethod, EmailRecipients, ExternalRecipient, - HumanInputNodeData, MemberRecipient, - UserAction, WebAppDeliveryMethod, ) -from dify_graph.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus -from dify_graph.repositories.human_input_form_repository import FormCreateParams, FormNotFoundError from libs.datetime_utils import naive_utc_now from models.human_input import HumanInputFormRecipient, RecipientType @@ -212,7 +212,7 @@ def test_recipient_entity_id_and_token_success() -> None: assert entity.token == "tok" -def test_form_entity_web_app_token_prefers_console_then_webapp_then_none() -> None: +def 
test_form_entity_submission_token_prefers_console_then_webapp_then_none() -> None: form = _DummyForm( id="f1", workflow_run_id="run", @@ -229,13 +229,13 @@ def test_form_entity_web_app_token_prefers_console_then_webapp_then_none() -> No ) entity = _HumanInputFormEntityImpl(form_model=form, recipient_models=[webapp, console]) # type: ignore[arg-type] - assert entity.web_app_token == "ctok" + assert entity.submission_token == "ctok" entity = _HumanInputFormEntityImpl(form_model=form, recipient_models=[webapp]) # type: ignore[arg-type] - assert entity.web_app_token == "wtok" + assert entity.submission_token == "wtok" entity = _HumanInputFormEntityImpl(form_model=form, recipient_models=[]) # type: ignore[arg-type] - assert entity.web_app_token is None + assert entity.submission_token is None def test_form_entity_submitted_data_parsed() -> None: @@ -364,8 +364,8 @@ def test_delivery_method_to_model_email_uses_build_email_recipients(monkeypatch: method = EmailDeliveryMethod( config=EmailDeliveryConfig( recipients=EmailRecipients( - whole_workspace=False, - items=[MemberRecipient(user_id="u1"), ExternalRecipient(email="e@example.com")], + include_bound_group=False, + items=[MemberRecipient(reference_id="u1"), ExternalRecipient(email="e@example.com")], ), subject="s", body="b", @@ -388,7 +388,7 @@ def test_build_email_recipients_uses_all_members_when_whole_workspace(monkeypatc session=MagicMock(), form_id="f", delivery_id="d", - recipients_config=EmailRecipients(whole_workspace=True, items=[ExternalRecipient(email="e@example.com")]), + recipients_config=EmailRecipients(include_bound_group=True, items=[ExternalRecipient(email="e@example.com")]), ) assert recipients == ["ok"] @@ -407,8 +407,8 @@ def test_build_email_recipients_uses_selected_members_when_not_whole_workspace(m form_id="f", delivery_id="d", recipients_config=EmailRecipients( - whole_workspace=False, - items=[MemberRecipient(user_id="u1"), ExternalRecipient(email="e@example.com")], + include_bound_group=False, + 
items=[MemberRecipient(reference_id="u1"), ExternalRecipient(email="e@example.com")], ), ) assert recipients == ["ok"] @@ -416,8 +416,8 @@ def test_build_email_recipients_uses_selected_members_when_not_whole_workspace(m def test_get_form_returns_entity_and_none_when_missing(monkeypatch: pytest.MonkeyPatch) -> None: _patch_session_factory(monkeypatch, _FakeSession(scalars_results=[None])) - repo = HumanInputFormRepositoryImpl(tenant_id="tenant") - assert repo.get_form("run", "node") is None + repo = HumanInputFormRepositoryImpl(tenant_id="tenant", workflow_execution_id="run") + assert repo.get_form("node") is None form = _DummyForm( id="f1", @@ -437,8 +437,8 @@ def test_get_form_returns_entity_and_none_when_missing(monkeypatch: pytest.Monke ) session = _FakeSession(scalars_results=[form, [recipient]]) _patch_session_factory(monkeypatch, session) - repo = HumanInputFormRepositoryImpl(tenant_id="tenant") - entity = repo.get_form("run", "node") + repo = HumanInputFormRepositoryImpl(tenant_id="tenant", workflow_execution_id="run") + entity = repo.get_form("node") assert entity is not None assert entity.id == "f1" assert entity.recipients[0].id == "r1" @@ -454,7 +454,13 @@ def test_create_form_adds_console_and_backstage_recipients(monkeypatch: pytest.M session = _FakeSession() _patch_session_factory(monkeypatch, session) - repo = HumanInputFormRepositoryImpl(tenant_id="tenant") + repo = HumanInputFormRepositoryImpl( + tenant_id="tenant", + app_id="app", + workflow_execution_id="run", + invoke_source="debugger", + submission_actor_id="acc-1", + ) form_config = HumanInputNodeData( title="Title", @@ -464,8 +470,7 @@ def test_create_form_adds_console_and_backstage_recipients(monkeypatch: pytest.M user_actions=[UserAction(id="submit", title="Submit")], ) params = FormCreateParams( - app_id="app", - workflow_execution_id="run", + workflow_execution_id=None, node_id="node", form_config=form_config, rendered_content="

hello

", @@ -473,16 +478,13 @@ def test_create_form_adds_console_and_backstage_recipients(monkeypatch: pytest.M display_in_ui=True, resolved_default_values={}, form_kind=HumanInputFormKind.RUNTIME, - console_recipient_required=True, - console_creator_account_id="acc-1", - backstage_recipient_required=True, ) entity = repo.create_form(params) assert entity.id == "form-id" assert entity.expiration_time == fixed_now + timedelta(hours=form_config.timeout) # Console token should take precedence when console recipient is present. - assert entity.web_app_token == "token-console" + assert entity.submission_token == "token-console" assert len(entity.recipients) == 3 diff --git a/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_execution_repository.py b/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_execution_repository.py index 232ab07882..e5c3e85487 100644 --- a/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_execution_repository.py +++ b/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_execution_repository.py @@ -3,11 +3,12 @@ from unittest.mock import MagicMock from uuid import uuid4 import pytest +from graphon.entities import WorkflowExecution +from graphon.enums import WorkflowExecutionStatus, WorkflowType from sqlalchemy.engine import Engine from sqlalchemy.orm import sessionmaker from core.repositories.sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository -from dify_graph.entities.workflow_execution import WorkflowExecution, WorkflowExecutionStatus, WorkflowType from models import Account, CreatorUserRole, EndUser, WorkflowRun from models.enums import WorkflowRunTriggeredFrom diff --git a/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_node_execution_repository.py b/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_node_execution_repository.py index 73de15e2cf..5b4d26b780 100644 --- 
a/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_node_execution_repository.py +++ b/api/tests/unit_tests/core/repositories/test_sqlalchemy_workflow_node_execution_repository.py @@ -10,11 +10,18 @@ from unittest.mock import MagicMock, Mock import psycopg2.errors import pytest +from graphon.entities import WorkflowNodeExecution +from graphon.enums import ( + BuiltinNodeTypes, + WorkflowNodeExecutionMetadataKey, + WorkflowNodeExecutionStatus, +) from sqlalchemy import Engine, create_engine from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import sessionmaker from configs import dify_config +from core.repositories.factory import OrderConfig from core.repositories.sqlalchemy_workflow_node_execution_repository import ( SQLAlchemyWorkflowNodeExecutionRepository, _deterministic_json_dump, @@ -22,13 +29,6 @@ from core.repositories.sqlalchemy_workflow_node_execution_repository import ( _find_first, _replace_or_append_offload, ) -from dify_graph.entities import WorkflowNodeExecution -from dify_graph.enums import ( - BuiltinNodeTypes, - WorkflowNodeExecutionMetadataKey, - WorkflowNodeExecutionStatus, -) -from dify_graph.repositories.workflow_node_execution_repository import OrderConfig from models import Account, EndUser from models.enums import ExecutionOffLoadType from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload, WorkflowNodeExecutionTriggeredFrom @@ -768,5 +768,5 @@ def test_get_by_workflow_run_maps_to_domain(monkeypatch: pytest.MonkeyPatch) -> lambda max_workers: FakeExecutor(), ) - result = repo.get_by_workflow_run("run", order_config=None) + result = repo.get_by_workflow_execution("run", order_config=None) assert result == ["domain:db1", "domain:db2"] diff --git a/api/tests/unit_tests/core/repositories/test_workflow_node_execution_conflict_handling.py b/api/tests/unit_tests/core/repositories/test_workflow_node_execution_conflict_handling.py index 456c3dde12..84fe522388 100644 --- 
a/api/tests/unit_tests/core/repositories/test_workflow_node_execution_conflict_handling.py +++ b/api/tests/unit_tests/core/repositories/test_workflow_node_execution_conflict_handling.py @@ -4,17 +4,17 @@ from unittest.mock import MagicMock, Mock import psycopg2.errors import pytest +from graphon.entities.workflow_node_execution import ( + WorkflowNodeExecution, + WorkflowNodeExecutionStatus, +) +from graphon.enums import BuiltinNodeTypes from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import sessionmaker from core.repositories.sqlalchemy_workflow_node_execution_repository import ( SQLAlchemyWorkflowNodeExecutionRepository, ) -from dify_graph.entities.workflow_node_execution import ( - WorkflowNodeExecution, - WorkflowNodeExecutionStatus, -) -from dify_graph.enums import BuiltinNodeTypes from libs.datetime_utils import naive_utc_now from models import Account, WorkflowNodeExecutionTriggeredFrom diff --git a/api/tests/unit_tests/core/repositories/test_workflow_node_execution_truncation.py b/api/tests/unit_tests/core/repositories/test_workflow_node_execution_truncation.py index eeab81a178..27729e7f06 100644 --- a/api/tests/unit_tests/core/repositories/test_workflow_node_execution_truncation.py +++ b/api/tests/unit_tests/core/repositories/test_workflow_node_execution_truncation.py @@ -11,17 +11,17 @@ from datetime import UTC, datetime from typing import Any from unittest.mock import MagicMock +from graphon.entities.workflow_node_execution import ( + WorkflowNodeExecution, + WorkflowNodeExecutionStatus, +) +from graphon.enums import BuiltinNodeTypes from sqlalchemy import Engine from configs import dify_config from core.repositories.sqlalchemy_workflow_node_execution_repository import ( SQLAlchemyWorkflowNodeExecutionRepository, ) -from dify_graph.entities.workflow_node_execution import ( - WorkflowNodeExecution, - WorkflowNodeExecutionStatus, -) -from dify_graph.enums import BuiltinNodeTypes from models import Account, WorkflowNodeExecutionTriggeredFrom 
from models.enums import ExecutionOffLoadType from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload diff --git a/api/tests/unit_tests/core/telemetry/test_facade.py b/api/tests/unit_tests/core/telemetry/test_facade.py new file mode 100644 index 0000000000..36e8e1bbb1 --- /dev/null +++ b/api/tests/unit_tests/core/telemetry/test_facade.py @@ -0,0 +1,181 @@ +"""Unit tests for core.telemetry.emit() routing and enterprise-only filtering.""" + +from __future__ import annotations + +import queue +import sys +import types +from unittest.mock import MagicMock, patch + +import pytest + +from core.ops.entities.trace_entity import TraceTaskName +from core.telemetry.events import TelemetryContext, TelemetryEvent + + +@pytest.fixture +def telemetry_test_setup(monkeypatch): + module_name = "core.ops.ops_trace_manager" + ops_stub = types.ModuleType(module_name) + + class StubTraceTask: + def __init__(self, trace_type, **kwargs): + self.trace_type = trace_type + self.app_id = None + self.kwargs = kwargs + + class StubTraceQueueManager: + def __init__(self, app_id=None, user_id=None): + self.app_id = app_id + self.user_id = user_id + self.trace_instance = StubOpsTraceManager.get_ops_trace_instance(app_id) + + def add_trace_task(self, trace_task): + trace_task.app_id = self.app_id + from core.ops.ops_trace_manager import trace_manager_queue + + trace_manager_queue.put(trace_task) + + class StubOpsTraceManager: + @staticmethod + def get_ops_trace_instance(app_id): + return None + + ops_stub.TraceQueueManager = StubTraceQueueManager + ops_stub.TraceTask = StubTraceTask + ops_stub.OpsTraceManager = StubOpsTraceManager + ops_stub.trace_manager_queue = MagicMock(spec=queue.Queue) + monkeypatch.setitem(sys.modules, module_name, ops_stub) + + from core.telemetry import emit + + return emit, ops_stub.trace_manager_queue + + +class TestTelemetryEmit: + @patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True) + def 
test_emit_enterprise_trace_creates_trace_task(self, mock_ee, telemetry_test_setup): + emit_fn, mock_queue = telemetry_test_setup + + event = TelemetryEvent( + name=TraceTaskName.DRAFT_NODE_EXECUTION_TRACE, + context=TelemetryContext( + tenant_id="test-tenant", + user_id="test-user", + app_id="test-app", + ), + payload={"key": "value"}, + ) + + emit_fn(event) + + mock_queue.put.assert_called_once() + called_task = mock_queue.put.call_args[0][0] + assert called_task.trace_type == TraceTaskName.DRAFT_NODE_EXECUTION_TRACE + + def test_emit_community_trace_enqueued(self, telemetry_test_setup): + emit_fn, mock_queue = telemetry_test_setup + + event = TelemetryEvent( + name=TraceTaskName.WORKFLOW_TRACE, + context=TelemetryContext( + tenant_id="test-tenant", + user_id="test-user", + app_id="test-app", + ), + payload={}, + ) + + emit_fn(event) + + mock_queue.put.assert_called_once() + + def test_emit_enterprise_only_trace_dropped_when_ee_disabled(self, telemetry_test_setup): + emit_fn, mock_queue = telemetry_test_setup + + event = TelemetryEvent( + name=TraceTaskName.DRAFT_NODE_EXECUTION_TRACE, + context=TelemetryContext( + tenant_id="test-tenant", + user_id="test-user", + app_id="test-app", + ), + payload={}, + ) + + emit_fn(event) + + mock_queue.put.assert_not_called() + + @patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True) + def test_emit_all_enterprise_only_traces_allowed_when_ee_enabled(self, mock_ee, telemetry_test_setup): + emit_fn, mock_queue = telemetry_test_setup + + enterprise_only_traces = [ + TraceTaskName.DRAFT_NODE_EXECUTION_TRACE, + TraceTaskName.NODE_EXECUTION_TRACE, + TraceTaskName.PROMPT_GENERATION_TRACE, + ] + + for trace_name in enterprise_only_traces: + mock_queue.reset_mock() + + event = TelemetryEvent( + name=trace_name, + context=TelemetryContext( + tenant_id="test-tenant", + user_id="test-user", + app_id="test-app", + ), + payload={}, + ) + + emit_fn(event) + + mock_queue.put.assert_called_once() + called_task = 
mock_queue.put.call_args[0][0] + assert called_task.trace_type == trace_name + + @patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True) + def test_emit_passes_name_directly_to_trace_task(self, mock_ee, telemetry_test_setup): + emit_fn, mock_queue = telemetry_test_setup + + event = TelemetryEvent( + name=TraceTaskName.DRAFT_NODE_EXECUTION_TRACE, + context=TelemetryContext( + tenant_id="test-tenant", + user_id="test-user", + app_id="test-app", + ), + payload={"extra": "data"}, + ) + + emit_fn(event) + + mock_queue.put.assert_called_once() + called_task = mock_queue.put.call_args[0][0] + assert called_task.trace_type == TraceTaskName.DRAFT_NODE_EXECUTION_TRACE + assert isinstance(called_task.trace_type, TraceTaskName) + + @patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True) + def test_emit_with_provided_trace_manager(self, mock_ee, telemetry_test_setup): + emit_fn, mock_queue = telemetry_test_setup + + mock_trace_manager = MagicMock() + mock_trace_manager.add_trace_task = MagicMock() + + event = TelemetryEvent( + name=TraceTaskName.NODE_EXECUTION_TRACE, + context=TelemetryContext( + tenant_id="test-tenant", + user_id="test-user", + app_id="test-app", + ), + payload={}, + ) + + emit_fn(event, trace_manager=mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_called_once() + called_task = mock_trace_manager.add_trace_task.call_args[0][0] + assert called_task.trace_type == TraceTaskName.NODE_EXECUTION_TRACE diff --git a/api/tests/unit_tests/core/telemetry/test_gateway_integration.py b/api/tests/unit_tests/core/telemetry/test_gateway_integration.py new file mode 100644 index 0000000000..a68fce5e7f --- /dev/null +++ b/api/tests/unit_tests/core/telemetry/test_gateway_integration.py @@ -0,0 +1,225 @@ +from __future__ import annotations + +import sys +from unittest.mock import MagicMock, patch + +import pytest + +from core.telemetry.gateway import emit, is_enterprise_telemetry_enabled +from 
enterprise.telemetry.contracts import TelemetryCase + + +class TestTelemetryCoreExports: + def test_is_enterprise_telemetry_enabled_exported(self) -> None: + from core.telemetry.gateway import is_enterprise_telemetry_enabled as exported_func + + assert callable(exported_func) + + +@pytest.fixture +def mock_ops_trace_manager(): + mock_module = MagicMock() + mock_trace_task_class = MagicMock() + mock_trace_task_class.return_value = MagicMock() + mock_module.TraceTask = mock_trace_task_class + mock_module.TraceQueueManager = MagicMock() + + mock_trace_entity = MagicMock() + mock_trace_task_name = MagicMock() + mock_trace_task_name.return_value = "workflow" + mock_trace_entity.TraceTaskName = mock_trace_task_name + + with ( + patch.dict(sys.modules, {"core.ops.ops_trace_manager": mock_module}), + patch.dict(sys.modules, {"core.ops.entities.trace_entity": mock_trace_entity}), + ): + yield mock_module, mock_trace_entity + + +class TestGatewayIntegrationTraceRouting: + @pytest.fixture + def mock_trace_manager(self) -> MagicMock: + return MagicMock() + + @pytest.mark.usefixtures("mock_ops_trace_manager") + def test_ce_eligible_trace_routed_to_trace_manager( + self, + mock_trace_manager: MagicMock, + ) -> None: + with patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True): + context = {"app_id": "app-123", "user_id": "user-456", "tenant_id": "tenant-789"} + payload = {"workflow_run_id": "run-abc"} + + emit(TelemetryCase.WORKFLOW_RUN, context, payload, mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_called_once() + + @pytest.mark.usefixtures("mock_ops_trace_manager") + def test_ce_eligible_trace_routed_when_ee_disabled( + self, + mock_trace_manager: MagicMock, + ) -> None: + with patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=False): + context = {"app_id": "app-123", "user_id": "user-456"} + payload = {"workflow_run_id": "run-abc"} + + emit(TelemetryCase.WORKFLOW_RUN, context, payload, 
mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_called_once() + + @pytest.mark.usefixtures("mock_ops_trace_manager") + def test_enterprise_only_trace_dropped_when_ee_disabled( + self, + mock_trace_manager: MagicMock, + ) -> None: + with patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=False): + context = {"app_id": "app-123", "user_id": "user-456"} + payload = {"node_id": "node-abc"} + + emit(TelemetryCase.NODE_EXECUTION, context, payload, mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_not_called() + + @pytest.mark.usefixtures("mock_ops_trace_manager") + def test_enterprise_only_trace_routed_when_ee_enabled( + self, + mock_trace_manager: MagicMock, + ) -> None: + with patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True): + context = {"app_id": "app-123", "user_id": "user-456"} + payload = {"node_id": "node-abc"} + + emit(TelemetryCase.NODE_EXECUTION, context, payload, mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_called_once() + + +class TestGatewayIntegrationMetricRouting: + @patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True) + def test_metric_case_routes_to_celery_task( + self, + mock_ee_enabled: MagicMock, + ) -> None: + from enterprise.telemetry.contracts import TelemetryEnvelope + + with patch("tasks.enterprise_telemetry_task.process_enterprise_telemetry.delay") as mock_delay: + context = {"tenant_id": "tenant-123"} + payload = {"app_id": "app-abc", "name": "My App"} + + emit(TelemetryCase.APP_CREATED, context, payload) + + mock_delay.assert_called_once() + envelope_json = mock_delay.call_args[0][0] + envelope = TelemetryEnvelope.model_validate_json(envelope_json) + assert envelope.case == TelemetryCase.APP_CREATED + assert envelope.tenant_id == "tenant-123" + assert envelope.payload["app_id"] == "app-abc" + + @pytest.mark.usefixtures("mock_ops_trace_manager") + 
@patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True) + def test_tool_execution_trace_routed( + self, + mock_ee_enabled: MagicMock, + ) -> None: + mock_trace_manager = MagicMock() + context = {"tenant_id": "tenant-123", "app_id": "app-123"} + payload = {"tool_name": "test_tool", "tool_inputs": {}, "tool_outputs": "result"} + + emit(TelemetryCase.TOOL_EXECUTION, context, payload, mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_called_once() + + @pytest.mark.usefixtures("mock_ops_trace_manager") + @patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True) + def test_moderation_check_trace_routed( + self, + mock_ee_enabled: MagicMock, + ) -> None: + mock_trace_manager = MagicMock() + context = {"tenant_id": "tenant-123", "app_id": "app-123"} + payload = {"message_id": "msg-123", "moderation_result": {"flagged": False}} + + emit(TelemetryCase.MODERATION_CHECK, context, payload, mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_called_once() + + +class TestGatewayIntegrationCEEligibility: + @pytest.fixture + def mock_trace_manager(self) -> MagicMock: + return MagicMock() + + @pytest.mark.usefixtures("mock_ops_trace_manager") + def test_workflow_run_is_ce_eligible( + self, + mock_trace_manager: MagicMock, + ) -> None: + with patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=False): + context = {"app_id": "app-123", "user_id": "user-456"} + payload = {"workflow_run_id": "run-abc"} + + emit(TelemetryCase.WORKFLOW_RUN, context, payload, mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_called_once() + + @pytest.mark.usefixtures("mock_ops_trace_manager") + def test_message_run_is_ce_eligible( + self, + mock_trace_manager: MagicMock, + ) -> None: + with patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=False): + context = {"app_id": "app-123", "user_id": "user-456"} + payload = {"message_id": "msg-abc", "conversation_id": 
"conv-123"} + + emit(TelemetryCase.MESSAGE_RUN, context, payload, mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_called_once() + + @pytest.mark.usefixtures("mock_ops_trace_manager") + def test_node_execution_not_ce_eligible( + self, + mock_trace_manager: MagicMock, + ) -> None: + with patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=False): + context = {"app_id": "app-123", "user_id": "user-456"} + payload = {"node_id": "node-abc"} + + emit(TelemetryCase.NODE_EXECUTION, context, payload, mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_not_called() + + @pytest.mark.usefixtures("mock_ops_trace_manager") + def test_draft_node_execution_not_ce_eligible( + self, + mock_trace_manager: MagicMock, + ) -> None: + with patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=False): + context = {"app_id": "app-123", "user_id": "user-456"} + payload = {"node_execution_data": {}} + + emit(TelemetryCase.DRAFT_NODE_EXECUTION, context, payload, mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_not_called() + + @pytest.mark.usefixtures("mock_ops_trace_manager") + def test_prompt_generation_not_ce_eligible( + self, + mock_trace_manager: MagicMock, + ) -> None: + with patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=False): + context = {"app_id": "app-123", "user_id": "user-456", "tenant_id": "tenant-789"} + payload = {"operation_type": "generate", "instruction": "test"} + + emit(TelemetryCase.PROMPT_GENERATION, context, payload, mock_trace_manager) + + mock_trace_manager.add_trace_task.assert_not_called() + + +class TestIsEnterpriseTelemetryEnabled: + def test_returns_false_when_exporter_import_fails(self) -> None: + with patch.dict(sys.modules, {"enterprise.telemetry.exporter": None}): + result = is_enterprise_telemetry_enabled() + assert result is False + + def test_function_is_callable(self) -> None: + assert callable(is_enterprise_telemetry_enabled) diff 
--git a/api/tests/unit_tests/core/test_file.py b/api/tests/unit_tests/core/test_file.py index 251d6fd25e..ac65d0c02b 100644 --- a/api/tests/unit_tests/core/test_file.py +++ b/api/tests/unit_tests/core/test_file.py @@ -1,6 +1,7 @@ import json -from dify_graph.file import File, FileTransferMethod, FileType, FileUploadConfig +from graphon.file import File, FileTransferMethod, FileType, FileUploadConfig + from models.workflow import Workflow diff --git a/api/tests/unit_tests/core/test_model_manager.py b/api/tests/unit_tests/core/test_model_manager.py index 92e4b58473..f5efb78b61 100644 --- a/api/tests/unit_tests/core/test_model_manager.py +++ b/api/tests/unit_tests/core/test_model_manager.py @@ -2,11 +2,11 @@ from unittest.mock import MagicMock, patch import pytest import redis +from graphon.model_runtime.entities.model_entities import ModelType from pytest_mock import MockerFixture from core.entities.provider_entities import ModelLoadBalancingConfiguration from core.model_manager import LBModelManager -from dify_graph.model_runtime.entities.model_entities import ModelType from extensions.ext_redis import redis_client diff --git a/api/tests/unit_tests/core/test_provider_configuration.py b/api/tests/unit_tests/core/test_provider_configuration.py index 90ed1647aa..331166fe63 100644 --- a/api/tests/unit_tests/core/test_provider_configuration.py +++ b/api/tests/unit_tests/core/test_provider_configuration.py @@ -1,6 +1,15 @@ from unittest.mock import Mock, patch import pytest +from graphon.model_runtime.entities.common_entities import I18nObject +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.entities.provider_entities import ( + ConfigurateMethod, + CredentialFormSchema, + FormOption, + FormType, + ProviderEntity, +) from core.entities.provider_configuration import ProviderConfiguration, SystemConfigurationStatus from core.entities.provider_entities import ( @@ -12,15 +21,6 @@ from core.entities.provider_entities import ( 
RestrictModel, SystemConfiguration, ) -from dify_graph.model_runtime.entities.common_entities import I18nObject -from dify_graph.model_runtime.entities.model_entities import ModelType -from dify_graph.model_runtime.entities.provider_entities import ( - ConfigurateMethod, - CredentialFormSchema, - FormOption, - FormType, - ProviderEntity, -) from models.provider import Provider, ProviderType diff --git a/api/tests/unit_tests/core/test_provider_manager.py b/api/tests/unit_tests/core/test_provider_manager.py index 69567c54eb..259cb5fdd0 100644 --- a/api/tests/unit_tests/core/test_provider_manager.py +++ b/api/tests/unit_tests/core/test_provider_manager.py @@ -1,12 +1,26 @@ -from unittest.mock import Mock, PropertyMock, patch +from types import SimpleNamespace +from unittest.mock import MagicMock, Mock, PropertyMock, patch import pytest +from graphon.model_runtime.entities.common_entities import I18nObject +from graphon.model_runtime.entities.model_entities import ModelType +from pytest_mock import MockerFixture from core.entities.provider_entities import ModelSettings from core.provider_manager import ProviderManager -from dify_graph.model_runtime.entities.common_entities import I18nObject -from dify_graph.model_runtime.entities.model_entities import ModelType -from models.provider import LoadBalancingModelConfig, ProviderModelSetting +from models.provider import LoadBalancingModelConfig, ProviderModelSetting, TenantDefaultModel +from models.provider_ids import ModelProviderID + + +def _build_provider_manager(mocker: MockerFixture) -> ProviderManager: + return ProviderManager(model_runtime=mocker.Mock()) + + +def _build_session_context(session: Mock) -> MagicMock: + session_cm = MagicMock() + session_cm.__enter__.return_value = session + session_cm.__exit__.return_value = False + return session_cm @pytest.fixture @@ -28,7 +42,7 @@ def mock_provider_entity(): return mock_entity -def test__to_model_settings(mock_provider_entity): +def test__to_model_settings(mocker: 
MockerFixture, mock_provider_entity): # Mocking the inputs ps = ProviderModelSetting( tenant_id="tenant_id", @@ -69,7 +83,7 @@ def test__to_model_settings(mock_provider_entity): "core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"}, ): - provider_manager = ProviderManager() + provider_manager = _build_provider_manager(mocker) # Running the method result = provider_manager._to_model_settings( @@ -89,7 +103,7 @@ def test__to_model_settings(mock_provider_entity): assert result[0].load_balancing_configs[1].name == "first" -def test__to_model_settings_only_one_lb(mock_provider_entity): +def test__to_model_settings_only_one_lb(mocker: MockerFixture, mock_provider_entity): # Mocking the inputs ps = ProviderModelSetting( @@ -119,7 +133,7 @@ def test__to_model_settings_only_one_lb(mock_provider_entity): "core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"}, ): - provider_manager = ProviderManager() + provider_manager = _build_provider_manager(mocker) # Running the method result = provider_manager._to_model_settings( @@ -137,7 +151,7 @@ def test__to_model_settings_only_one_lb(mock_provider_entity): assert len(result[0].load_balancing_configs) == 0 -def test__to_model_settings_lb_disabled(mock_provider_entity): +def test__to_model_settings_lb_disabled(mocker: MockerFixture, mock_provider_entity): # Mocking the inputs ps = ProviderModelSetting( tenant_id="tenant_id", @@ -176,7 +190,7 @@ def test__to_model_settings_lb_disabled(mock_provider_entity): "core.helper.model_provider_cache.ProviderCredentialsCache.get", return_value={"openai_api_key": "fake_key"}, ): - provider_manager = ProviderManager() + provider_manager = _build_provider_manager(mocker) # Running the method result = provider_manager._to_model_settings( @@ -194,7 +208,7 @@ def test__to_model_settings_lb_disabled(mock_provider_entity): assert len(result[0].load_balancing_configs) == 0 -def 
test_get_default_model_uses_first_available_active_model(): +def test_get_default_model_uses_first_available_active_model(mocker: MockerFixture): mock_session = Mock() mock_session.scalar.return_value = None @@ -204,7 +218,7 @@ def test_get_default_model_uses_first_available_active_model(): Mock(model="gpt-4", provider=Mock(provider="openai")), ] - manager = ProviderManager() + manager = _build_provider_manager(mocker) with ( patch("core.provider_manager.db.session", mock_session), patch.object(manager, "get_configurations", return_value=provider_configurations), @@ -228,3 +242,345 @@ def test_get_default_model_uses_first_available_active_model(): assert saved_default_model.model_name == "gpt-3.5-turbo" assert saved_default_model.provider_name == "openai" mock_session.commit.assert_called_once() + + +def test_get_default_model_returns_none_when_no_default_or_active_models(mocker: MockerFixture): + mock_session = Mock() + mock_session.scalar.return_value = None + provider_configurations = Mock() + provider_configurations.get_models.return_value = [] + manager = _build_provider_manager(mocker) + + with ( + patch("core.provider_manager.db.session", mock_session), + patch.object(manager, "get_configurations", return_value=provider_configurations), + patch("core.provider_manager.ModelProviderFactory") as mock_factory_cls, + ): + result = manager.get_default_model("tenant-id", ModelType.LLM) + + assert result is None + provider_configurations.get_models.assert_called_once_with(model_type=ModelType.LLM, only_active=True) + mock_factory_cls.assert_not_called() + mock_session.add.assert_not_called() + mock_session.commit.assert_not_called() + + +def test_get_default_model_uses_injected_runtime_for_existing_default_record(mocker: MockerFixture): + existing_default_model = TenantDefaultModel( + tenant_id="tenant-id", + provider_name="openai", + model_name="gpt-4", + model_type=ModelType.LLM.to_origin_model_type(), + ) + mock_session = Mock() + mock_session.scalar.return_value 
= existing_default_model + manager = _build_provider_manager(mocker) + + with ( + patch("core.provider_manager.db.session", mock_session), + patch("core.provider_manager.ModelProviderFactory") as mock_factory_cls, + ): + mock_factory_cls.return_value.get_provider_schema.return_value = Mock( + provider="openai", + label=I18nObject(en_US="OpenAI", zh_Hans="OpenAI"), + icon_small=I18nObject(en_US="icon_small.png", zh_Hans="icon_small.png"), + supported_model_types=[ModelType.LLM], + ) + + result = manager.get_default_model("tenant-id", ModelType.LLM) + + mock_factory_cls.assert_called_once_with(model_runtime=manager._model_runtime) + assert result is not None + assert result.model == "gpt-4" + assert result.provider.provider == "openai" + + +def test_get_configurations_uses_injected_runtime_and_adds_provider_aliases(mocker: MockerFixture): + manager = _build_provider_manager(mocker) + provider_records = {"openai": [SimpleNamespace(provider_name="openai")]} + provider_model_records = {"openai": [SimpleNamespace(provider_name="openai")]} + preferred_provider_records = {"openai": SimpleNamespace(preferred_provider_type="system")} + + with ( + patch.object(manager, "_get_all_providers", return_value=provider_records), + patch.object(manager, "_init_trial_provider_records", return_value=provider_records), + patch.object(manager, "_get_all_provider_models", return_value=provider_model_records), + patch.object(manager, "_get_all_preferred_model_providers", return_value=preferred_provider_records), + patch.object(manager, "_get_all_provider_model_settings", return_value={}), + patch.object(manager, "_get_all_provider_load_balancing_configs", return_value={}), + patch.object(manager, "_get_all_provider_model_credentials", return_value={}), + patch("core.provider_manager.ModelProviderFactory") as mock_factory_cls, + ): + mock_factory_cls.return_value.get_providers.return_value = [] + + result = manager.get_configurations("tenant-id") + + expected_alias = 
str(ModelProviderID("openai")) + mock_factory_cls.assert_called_once_with(model_runtime=manager._model_runtime) + assert result.tenant_id == "tenant-id" + assert expected_alias in provider_records + assert expected_alias in provider_model_records + assert expected_alias in preferred_provider_records + + +@pytest.mark.parametrize( + ("provider_name", "expected_provider_names"), + [ + ("openai", ["openai", "langgenius/openai/openai"]), + ("langgenius/openai/openai", ["langgenius/openai/openai", "openai"]), + ("langgenius/gemini/google", ["langgenius/gemini/google", "google"]), + ], +) +def test_get_provider_names_returns_short_and_full_aliases(provider_name: str, expected_provider_names: list[str]): + assert ProviderManager._get_provider_names(provider_name) == expected_provider_names + + +def test_get_provider_model_bundle_raises_for_unknown_provider(mocker: MockerFixture): + manager = _build_provider_manager(mocker) + + with patch.object(manager, "get_configurations", return_value={}): + with pytest.raises(ValueError, match="Provider openai does not exist."): + manager.get_provider_model_bundle("tenant-id", "openai", ModelType.LLM) + + +def test_get_configurations_binds_manager_runtime_to_provider_configuration( + mocker: MockerFixture, mock_provider_entity +): + manager = _build_provider_manager(mocker) + provider_configuration = Mock() + provider_factory = Mock() + provider_factory.get_providers.return_value = [mock_provider_entity] + custom_configuration = SimpleNamespace(provider=None, models=[]) + system_configuration = SimpleNamespace(enabled=False, quota_configurations=[], current_quota_type=None) + + with ( + patch.object(manager, "_get_all_providers", return_value={"openai": []}), + patch.object(manager, "_init_trial_provider_records", return_value={"openai": []}), + patch.object(manager, "_get_all_provider_models", return_value={"openai": []}), + patch.object(manager, "_get_all_preferred_model_providers", return_value={}), + patch.object(manager, 
"_get_all_provider_model_settings", return_value={}), + patch.object(manager, "_get_all_provider_load_balancing_configs", return_value={}), + patch.object(manager, "_get_all_provider_model_credentials", return_value={}), + patch.object(manager, "_to_custom_configuration", return_value=custom_configuration), + patch.object(manager, "_to_system_configuration", return_value=system_configuration), + patch.object(manager, "_to_model_settings", return_value=[]), + patch("core.provider_manager.ModelProviderFactory", return_value=provider_factory), + patch("core.provider_manager.ProviderConfiguration", return_value=provider_configuration), + ): + manager.get_configurations("tenant-id") + + provider_configuration.bind_model_runtime.assert_called_once_with(manager._model_runtime) + + +def test_get_provider_model_bundle_returns_selected_model_type_instance(mocker: MockerFixture): + manager = _build_provider_manager(mocker) + provider_configuration = Mock() + model_type_instance = Mock() + provider_configuration.get_model_type_instance.return_value = model_type_instance + expected_bundle = Mock() + + with ( + patch.object(manager, "get_configurations", return_value={"openai": provider_configuration}), + patch("core.provider_manager.ProviderModelBundle", return_value=expected_bundle) as mock_bundle, + ): + result = manager.get_provider_model_bundle("tenant-id", "openai", ModelType.LLM) + + provider_configuration.get_model_type_instance.assert_called_once_with(ModelType.LLM) + mock_bundle.assert_called_once_with( + configuration=provider_configuration, + model_type_instance=model_type_instance, + ) + assert result is expected_bundle + + +def test_get_first_provider_first_model_returns_none_when_no_models(mocker: MockerFixture): + manager = _build_provider_manager(mocker) + provider_configurations = Mock() + provider_configurations.get_models.return_value = [] + + with patch.object(manager, "get_configurations", return_value=provider_configurations): + result = 
manager.get_first_provider_first_model("tenant-id", ModelType.LLM) + + assert result == (None, None) + provider_configurations.get_models.assert_called_once_with(model_type=ModelType.LLM, only_active=False) + + +def test_get_first_provider_first_model_returns_first_model_and_provider(mocker: MockerFixture): + manager = _build_provider_manager(mocker) + provider_configurations = Mock() + provider_configurations.get_models.return_value = [ + Mock(model="gpt-4", provider=Mock(provider="openai")), + Mock(model="gpt-4o", provider=Mock(provider="openai")), + ] + + with patch.object(manager, "get_configurations", return_value=provider_configurations): + result = manager.get_first_provider_first_model("tenant-id", ModelType.LLM) + + assert result == ("openai", "gpt-4") + + +def test_update_default_model_record_raises_for_unknown_provider(mocker: MockerFixture): + manager = _build_provider_manager(mocker) + + with patch.object(manager, "get_configurations", return_value={}): + with pytest.raises(ValueError, match="Provider openai does not exist."): + manager.update_default_model_record("tenant-id", ModelType.LLM, "openai", "gpt-4") + + +def test_update_default_model_record_raises_for_unknown_model(mocker: MockerFixture): + manager = _build_provider_manager(mocker) + provider_configurations = MagicMock() + provider_configurations.__contains__.return_value = True + provider_configurations.get_models.return_value = [Mock(model="gpt-4")] + + with patch.object(manager, "get_configurations", return_value=provider_configurations): + with pytest.raises(ValueError, match="Model gpt-3.5-turbo does not exist."): + manager.update_default_model_record("tenant-id", ModelType.LLM, "openai", "gpt-3.5-turbo") + + provider_configurations.get_models.assert_called_once_with(model_type=ModelType.LLM, only_active=True) + + +def test_update_default_model_record_updates_existing_record(mocker: MockerFixture): + manager = _build_provider_manager(mocker) + provider_configurations = MagicMock() + 
provider_configurations.__contains__.return_value = True + provider_configurations.get_models.return_value = [Mock(model="gpt-3.5-turbo")] + existing_default_model = TenantDefaultModel( + tenant_id="tenant-id", + provider_name="anthropic", + model_name="claude-3-sonnet", + model_type=ModelType.LLM.to_origin_model_type(), + ) + mock_session = Mock() + mock_session.scalar.return_value = existing_default_model + + with ( + patch.object(manager, "get_configurations", return_value=provider_configurations), + patch("core.provider_manager.db.session", mock_session), + ): + result = manager.update_default_model_record("tenant-id", ModelType.LLM, "openai", "gpt-3.5-turbo") + + assert result is existing_default_model + assert existing_default_model.provider_name == "openai" + assert existing_default_model.model_name == "gpt-3.5-turbo" + mock_session.commit.assert_called_once() + mock_session.add.assert_not_called() + + +def test_update_default_model_record_creates_record_with_origin_model_type(mocker: MockerFixture): + manager = _build_provider_manager(mocker) + provider_configurations = MagicMock() + provider_configurations.__contains__.return_value = True + provider_configurations.get_models.return_value = [Mock(model="gpt-4")] + mock_session = Mock() + mock_session.scalar.return_value = None + + with ( + patch.object(manager, "get_configurations", return_value=provider_configurations), + patch("core.provider_manager.db.session", mock_session), + ): + result = manager.update_default_model_record("tenant-id", ModelType.LLM, "openai", "gpt-4") + + mock_session.add.assert_called_once() + created_default_model = mock_session.add.call_args.args[0] + assert result is created_default_model + assert created_default_model.tenant_id == "tenant-id" + assert created_default_model.provider_name == "openai" + assert created_default_model.model_name == "gpt-4" + assert created_default_model.model_type == ModelType.LLM.to_origin_model_type() + mock_session.commit.assert_called_once() + + 
+def test_get_all_providers_normalizes_provider_names_with_model_provider_id() -> None: + session = Mock() + openai_provider = SimpleNamespace(provider_name="openai") + gemini_provider = SimpleNamespace(provider_name="langgenius/gemini/google") + session.scalars.return_value = [openai_provider, gemini_provider] + + with ( + patch("core.provider_manager.db", SimpleNamespace(engine=object())), + patch("core.provider_manager.Session", return_value=_build_session_context(session)), + ): + result = ProviderManager._get_all_providers("tenant-id") + + assert list(result[str(ModelProviderID("openai"))]) == [openai_provider] + assert list(result[str(ModelProviderID("langgenius/gemini/google"))]) == [gemini_provider] + + +@pytest.mark.parametrize( + "method_name", + [ + "_get_all_provider_models", + "_get_all_provider_model_settings", + "_get_all_provider_model_credentials", + ], +) +def test_provider_grouping_helpers_group_records_by_provider_name(method_name: str) -> None: + session = Mock() + openai_primary = SimpleNamespace(provider_name="openai") + openai_secondary = SimpleNamespace(provider_name="openai") + anthropic_record = SimpleNamespace(provider_name="anthropic") + session.scalars.return_value = [openai_primary, openai_secondary, anthropic_record] + + with ( + patch("core.provider_manager.db", SimpleNamespace(engine=object())), + patch("core.provider_manager.Session", return_value=_build_session_context(session)), + ): + result = getattr(ProviderManager, method_name)("tenant-id") + + assert list(result["openai"]) == [openai_primary, openai_secondary] + assert list(result["anthropic"]) == [anthropic_record] + + +def test_get_all_preferred_model_providers_returns_mapping_by_provider_name() -> None: + session = Mock() + openai_preference = SimpleNamespace(provider_name="openai") + anthropic_preference = SimpleNamespace(provider_name="anthropic") + session.scalars.return_value = [openai_preference, anthropic_preference] + + with ( + patch("core.provider_manager.db", 
SimpleNamespace(engine=object())), + patch("core.provider_manager.Session", return_value=_build_session_context(session)), + ): + result = ProviderManager._get_all_preferred_model_providers("tenant-id") + + assert result == { + "openai": openai_preference, + "anthropic": anthropic_preference, + } + + +def test_get_all_provider_load_balancing_configs_returns_empty_when_cached_flag_is_disabled() -> None: + with ( + patch("core.provider_manager.redis_client.get", return_value=b"False"), + patch("core.provider_manager.FeatureService.get_features") as mock_get_features, + patch("core.provider_manager.Session") as mock_session_cls, + ): + result = ProviderManager._get_all_provider_load_balancing_configs("tenant-id") + + assert result == {} + mock_get_features.assert_not_called() + mock_session_cls.assert_not_called() + + +def test_get_all_provider_load_balancing_configs_populates_cache_and_groups_configs() -> None: + session = Mock() + openai_config = SimpleNamespace(provider_name="openai") + anthropic_config = SimpleNamespace(provider_name="anthropic") + session.scalars.return_value = [openai_config, anthropic_config] + + with ( + patch("core.provider_manager.db", SimpleNamespace(engine=object())), + patch("core.provider_manager.redis_client.get", return_value=None), + patch("core.provider_manager.redis_client.setex") as mock_setex, + patch( + "core.provider_manager.FeatureService.get_features", + return_value=SimpleNamespace(model_load_balancing_enabled=True), + ), + patch("core.provider_manager.Session", return_value=_build_session_context(session)), + ): + result = ProviderManager._get_all_provider_load_balancing_configs("tenant-id") + + mock_setex.assert_called_once_with("tenant:tenant-id:model_load_balancing_enabled", 120, "True") + assert list(result["openai"]) == [openai_config] + assert list(result["anthropic"]) == [anthropic_config] diff --git a/api/tests/unit_tests/core/tools/test_builtin_tool_base.py b/api/tests/unit_tests/core/tools/test_builtin_tool_base.py 
index f123f60a34..5d744f88c9 100644 --- a/api/tests/unit_tests/core/tools/test_builtin_tool_base.py +++ b/api/tests/unit_tests/core/tools/test_builtin_tool_base.py @@ -6,13 +6,13 @@ from typing import Any from unittest.mock import patch import pytest +from graphon.model_runtime.entities.message_entities import UserPromptMessage from core.app.entities.app_invoke_entities import InvokeFrom from core.tools.__base.tool_runtime import ToolRuntime from core.tools.builtin_tool.tool import BuiltinTool from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolEntity, ToolIdentity, ToolInvokeMessage, ToolProviderType -from dify_graph.model_runtime.entities.message_entities import UserPromptMessage class _BuiltinDummyTool(BuiltinTool): @@ -27,12 +27,12 @@ class _BuiltinDummyTool(BuiltinTool): yield self.create_text_message("ok") -def _build_tool() -> _BuiltinDummyTool: +def _build_tool(user_id: str | None = None) -> _BuiltinDummyTool: entity = ToolEntity( identity=ToolIdentity(author="author", name="tool-a", label=I18nObject(en_US="tool-a"), provider="provider-a"), parameters=[], ) - runtime = ToolRuntime(tenant_id="tenant-1", invoke_from=InvokeFrom.DEBUGGER) + runtime = ToolRuntime(tenant_id="tenant-1", user_id=user_id, invoke_from=InvokeFrom.DEBUGGER) return _BuiltinDummyTool(provider="provider-a", entity=entity, runtime=runtime) @@ -45,7 +45,7 @@ def test_builtin_tool_fork_and_provider_type(): def test_invoke_model_calls_model_invocation_utils_invoke(): - tool = _build_tool() + tool = _build_tool(user_id="runtime-user") with patch("core.tools.builtin_tool.tool.ModelInvocationUtils.invoke", return_value="result") as mock_invoke: assert ( tool.invoke_model( @@ -55,19 +55,47 @@ def test_invoke_model_calls_model_invocation_utils_invoke(): ) == "result" ) - mock_invoke.assert_called_once() + mock_invoke.assert_called_once_with( + user_id="u1", + tenant_id="tenant-1", + tool_type=ToolProviderType.BUILT_IN, + tool_name="tool-a", 
+ prompt_messages=[UserPromptMessage(content="hello")], + caller_user_id="runtime-user", + ) def test_get_max_tokens_returns_value(): - tool = _build_tool() - with patch("core.tools.builtin_tool.tool.ModelInvocationUtils.get_max_llm_context_tokens", return_value=4096): + tool = _build_tool(user_id="runtime-user") + with patch( + "core.tools.builtin_tool.tool.ModelInvocationUtils.get_max_llm_context_tokens", return_value=4096 + ) as mock_get: assert tool.get_max_tokens() == 4096 + mock_get.assert_called_once_with(tenant_id="tenant-1", user_id="runtime-user") def test_get_prompt_tokens_returns_value(): - tool = _build_tool() - with patch("core.tools.builtin_tool.tool.ModelInvocationUtils.calculate_tokens", return_value=7): + tool = _build_tool(user_id="runtime-user") + with patch("core.tools.builtin_tool.tool.ModelInvocationUtils.calculate_tokens", return_value=7) as mock_calculate: assert tool.get_prompt_tokens([UserPromptMessage(content="hello")]) == 7 + mock_calculate.assert_called_once_with( + tenant_id="tenant-1", + prompt_messages=[UserPromptMessage(content="hello")], + user_id="runtime-user", + ) + + +def test_get_prompt_tokens_falls_back_to_tenant_scope_when_runtime_user_id_missing(): + tool = _build_tool() + + with patch("core.tools.builtin_tool.tool.ModelInvocationUtils.calculate_tokens", return_value=7) as mock_calculate: + assert tool.get_prompt_tokens([UserPromptMessage(content="hello")]) == 7 + + mock_calculate.assert_called_once_with( + tenant_id="tenant-1", + prompt_messages=[UserPromptMessage(content="hello")], + user_id=None, + ) def test_runtime_none_raises(): diff --git a/api/tests/unit_tests/core/tools/test_builtin_tools_extra.py b/api/tests/unit_tests/core/tools/test_builtin_tools_extra.py index 62cfb6ce5b..ee0ce51eec 100644 --- a/api/tests/unit_tests/core/tools/test_builtin_tools_extra.py +++ b/api/tests/unit_tests/core/tools/test_builtin_tools_extra.py @@ -1,9 +1,13 @@ from __future__ import annotations +import calendar import math +from 
datetime import date from types import SimpleNamespace import pytest +from graphon.file import FileType +from graphon.model_runtime.entities.model_entities import ModelPropertyKey from core.app.entities.app_invoke_entities import InvokeFrom from core.tools.__base.tool_runtime import ToolRuntime @@ -25,8 +29,6 @@ from core.tools.builtin_tool.tool import BuiltinTool from core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ToolEntity, ToolIdentity, ToolInvokeMessage from core.tools.errors import ToolInvokeError -from dify_graph.file.enums import FileType -from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey def _build_builtin_tool(tool_cls: type[BuiltinTool]) -> BuiltinTool: @@ -98,7 +100,13 @@ def test_timezone_conversion_tool(): def test_weekday_tool(): weekday_tool = _build_builtin_tool(WeekdayTool) valid = list(weekday_tool.invoke(user_id="u", tool_parameters={"year": 2024, "month": 1, "day": 1}))[0].message.text - assert "January 1, 2024" in valid + expected_date = date(2024, 1, 1) + expected_message = ( + f"{calendar.month_name[expected_date.month]} " + f"{expected_date.day}, {expected_date.year} " + f"is {calendar.day_name[expected_date.weekday()]}." 
+ ) + assert valid == expected_message invalid = list(weekday_tool.invoke(user_id="u", tool_parameters={"year": 2024, "month": 2, "day": 31}))[ 0 ].message.text @@ -186,13 +194,19 @@ def test_asr_invalid_file(): def test_asr_valid_file_invocation(monkeypatch): asr = _build_builtin_tool(ASRTool) - model_instance = type("M", (), {"invoke_speech2text": lambda self, file, user: "transcript"})() + model_instance = type("M", (), {"invoke_speech2text": lambda self, file: "transcript"})() model_manager = type("Mgr", (), {"get_model_instance": lambda *a, **k: model_instance})() monkeypatch.setattr("core.tools.builtin_tool.providers.audio.tools.asr.download", lambda file: b"audio-bytes") - monkeypatch.setattr("core.tools.builtin_tool.providers.audio.tools.asr.ModelManager", lambda: model_manager) + captured_manager_kwargs = {} + + monkeypatch.setattr( + "core.tools.builtin_tool.providers.audio.tools.asr.ModelManager.for_tenant", + lambda **kwargs: captured_manager_kwargs.update(kwargs) or model_manager, + ) audio_file = SimpleNamespace(type=FileType.AUDIO) ok = list(asr.invoke(user_id="u", tool_parameters={"audio_file": audio_file, "model": "p#m"}))[0].message.text assert ok == "transcript" + assert captured_manager_kwargs == {"tenant_id": "tenant-1", "user_id": "u"} def test_asr_available_models_and_runtime_parameters(monkeypatch): @@ -208,6 +222,7 @@ def test_asr_available_models_and_runtime_parameters(monkeypatch): def test_tts_invoke_returns_messages(monkeypatch): tts = _build_builtin_tool(TTSTool) + captured_manager_kwargs = {} voices_model_instance = type( "TTSM", (), @@ -217,11 +232,15 @@ def test_tts_invoke_returns_messages(monkeypatch): }, )() monkeypatch.setattr( - "core.tools.builtin_tool.providers.audio.tools.tts.ModelManager", - lambda: type("M", (), {"get_model_instance": lambda *a, **k: voices_model_instance})(), + "core.tools.builtin_tool.providers.audio.tools.tts.ModelManager.for_tenant", + lambda **kwargs: ( + captured_manager_kwargs.update(kwargs) + or 
type("M", (), {"get_model_instance": lambda *a, **k: voices_model_instance})() + ), ) messages = list(tts.invoke(user_id="u", tool_parameters={"model": "p#m", "text": "hello"})) assert [m.type for m in messages] == [ToolInvokeMessage.MessageType.TEXT, ToolInvokeMessage.MessageType.BLOB] + assert captured_manager_kwargs == {"tenant_id": "tenant-1", "user_id": "u"} def test_tts_get_available_models_requires_runtime(): @@ -254,8 +273,8 @@ def test_tts_tool_raises_when_voice_unavailable(monkeypatch, voices): }, )() monkeypatch.setattr( - "core.tools.builtin_tool.providers.audio.tools.tts.ModelManager", - lambda: type("Manager", (), {"get_model_instance": lambda *args, **kwargs: model_without_voice})(), + "core.tools.builtin_tool.providers.audio.tools.tts.ModelManager.for_tenant", + lambda **_: type("Manager", (), {"get_model_instance": lambda *args, **kwargs: model_without_voice})(), ) with pytest.raises(ValueError, match="no voice available"): list(tts.invoke(user_id="u", tool_parameters={"model": "p#m", "text": "hello"})) diff --git a/api/tests/unit_tests/core/tools/test_signature.py b/api/tests/unit_tests/core/tools/test_signature.py index a5242a78c5..353988d7a6 100644 --- a/api/tests/unit_tests/core/tools/test_signature.py +++ b/api/tests/unit_tests/core/tools/test_signature.py @@ -6,7 +6,13 @@ from urllib.parse import parse_qs, urlparse import pytest -from core.tools.signature import sign_tool_file, sign_upload_file, verify_tool_file_signature +from core.tools.signature import ( + get_signed_file_url_for_plugin, + sign_tool_file, + sign_upload_file, + verify_plugin_file_signature, + verify_tool_file_signature, +) def test_sign_tool_file_and_verify_roundtrip(monkeypatch: pytest.MonkeyPatch) -> None: @@ -117,3 +123,82 @@ def test_sign_upload_file_uses_files_url_fallback(monkeypatch: pytest.MonkeyPatc assert query["timestamp"][0] assert query["nonce"][0] assert query["sign"][0] + + +def test_get_signed_file_url_for_plugin_and_verify_roundtrip(monkeypatch: 
pytest.MonkeyPatch) -> None: + monkeypatch.setattr("core.tools.signature.time.time", lambda: 1700000000) + monkeypatch.setattr("core.tools.signature.os.urandom", lambda _: b"\x06" * 16) + monkeypatch.setattr("core.tools.signature.dify_config.SECRET_KEY", "unit-secret") + monkeypatch.setattr("core.tools.signature.dify_config.FILES_URL", "https://files.example.com") + monkeypatch.setattr("core.tools.signature.dify_config.INTERNAL_FILES_URL", "https://internal.example.com") + monkeypatch.setattr("core.tools.signature.dify_config.FILES_ACCESS_TIMEOUT", 60) + + url = get_signed_file_url_for_plugin( + filename="report.pdf", + mimetype="application/pdf", + tenant_id="tenant-id", + user_id="user-id", + ) + parsed = urlparse(url) + query = parse_qs(parsed.query) + + assert parsed.netloc == "internal.example.com" + assert parsed.path == "/files/upload/for-plugin" + assert query["tenant_id"] == ["tenant-id"] + assert query["user_id"] == ["user-id"] + assert ( + verify_plugin_file_signature( + filename="report.pdf", + mimetype="application/pdf", + tenant_id="tenant-id", + user_id="user-id", + timestamp=query["timestamp"][0], + nonce=query["nonce"][0], + sign=query["sign"][0], + ) + is True + ) + + +def test_verify_plugin_file_signature_rejects_invalid_signatures(monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.setattr("core.tools.signature.time.time", lambda: 1700000000) + monkeypatch.setattr("core.tools.signature.os.urandom", lambda _: b"\x07" * 16) + monkeypatch.setattr("core.tools.signature.dify_config.SECRET_KEY", "unit-secret") + monkeypatch.setattr("core.tools.signature.dify_config.FILES_URL", "https://files.example.com") + monkeypatch.setattr("core.tools.signature.dify_config.INTERNAL_FILES_URL", "") + monkeypatch.setattr("core.tools.signature.dify_config.FILES_ACCESS_TIMEOUT", 30) + + url = get_signed_file_url_for_plugin( + filename="report.pdf", + mimetype="application/pdf", + tenant_id="tenant-id", + user_id="user-id", + ) + query = 
parse_qs(urlparse(url).query) + + assert ( + verify_plugin_file_signature( + filename="report.pdf", + mimetype="application/pdf", + tenant_id="tenant-id", + user_id="user-id", + timestamp=query["timestamp"][0], + nonce=query["nonce"][0], + sign="bad-signature", + ) + is False + ) + + monkeypatch.setattr("core.tools.signature.time.time", lambda: 1700000100) + assert ( + verify_plugin_file_signature( + filename="report.pdf", + mimetype="application/pdf", + tenant_id="tenant-id", + user_id="user-id", + timestamp=query["timestamp"][0], + nonce=query["nonce"][0], + sign=query["sign"][0], + ) + is False + ) diff --git a/api/tests/unit_tests/core/tools/test_tool_file_manager.py b/api/tests/unit_tests/core/tools/test_tool_file_manager.py index cca8254dd6..7fcebde3c5 100644 --- a/api/tests/unit_tests/core/tools/test_tool_file_manager.py +++ b/api/tests/unit_tests/core/tools/test_tool_file_manager.py @@ -12,6 +12,7 @@ from unittest.mock import MagicMock, Mock, patch import httpx import pytest +from graphon.file import FileTransferMethod from core.tools.tool_file_manager import ToolFileManager @@ -232,7 +233,14 @@ def test_get_file_generator_returns_none_when_toolfile_missing() -> None: def test_get_file_generator_returns_stream_when_found() -> None: # Arrange manager = ToolFileManager() - tool_file = SimpleNamespace(file_key="k2", mimetype="image/png") + tool_file = SimpleNamespace( + id="tool123", + file_key="k2", + mimetype="image/png", + original_url=None, + name="image.png", + size=12, + ) session = Mock() session.query.return_value.where.return_value.first.return_value = tool_file @@ -240,10 +248,10 @@ def test_get_file_generator_returns_stream_when_found() -> None: with patch("core.tools.tool_file_manager.storage") as storage: stream = iter([b"a", b"b"]) storage.load_stream.return_value = stream - with ( - _patch_session_factory(session), - patch("core.tools.tool_file_manager.ToolFilePydanticModel.model_validate", return_value="validated-file"), - ): + with 
_patch_session_factory(session): result_stream, result_file = manager.get_file_generator_by_tool_file_id("tool123") assert list(result_stream) == [b"a", b"b"] - assert result_file == "validated-file" + assert result_file is not None + assert result_file.related_id == "tool123" + assert result_file.mime_type == "image/png" + assert result_file.transfer_method == FileTransferMethod.TOOL_FILE diff --git a/api/tests/unit_tests/core/tools/test_tool_label_manager.py b/api/tests/unit_tests/core/tools/test_tool_label_manager.py index 857f4aa178..8c0e7e9419 100644 --- a/api/tests/unit_tests/core/tools/test_tool_label_manager.py +++ b/api/tests/unit_tests/core/tools/test_tool_label_manager.py @@ -38,11 +38,9 @@ def test_tool_label_manager_filter_tool_labels(): def test_tool_label_manager_update_tool_labels_db(): controller = _api_controller("api-1") with patch("core.tools.tool_label_manager.db") as mock_db: - delete_query = mock_db.session.query.return_value.where.return_value - delete_query.delete.return_value = None ToolLabelManager.update_tool_labels(controller, ["search", "search", "invalid"]) - delete_query.delete.assert_called_once() + mock_db.session.execute.assert_called_once() # only one valid unique label should be inserted. 
assert mock_db.session.add.call_count == 1 mock_db.session.commit.assert_called_once() diff --git a/api/tests/unit_tests/core/tools/test_tool_manager.py b/api/tests/unit_tests/core/tools/test_tool_manager.py index 0f73e22654..31b68f0b3f 100644 --- a/api/tests/unit_tests/core/tools/test_tool_manager.py +++ b/api/tests/unit_tests/core/tools/test_tool_manager.py @@ -15,6 +15,7 @@ from core.plugin.entities.plugin_daemon import CredentialType from core.tools.__base.tool_runtime import ToolRuntime from core.tools.entities.tool_entities import ( ApiProviderAuthType, + ToolInvokeFrom, ToolParameter, ToolProviderType, ) @@ -219,9 +220,7 @@ def test_get_tool_runtime_builtin_with_credentials_decrypts_and_forks(): with patch.object(ToolManager, "get_builtin_provider", return_value=controller): with patch("core.helper.credential_utils.check_credential_policy_compliance"): with patch("core.tools.tool_manager.db") as mock_db: - mock_db.session.query.return_value.where.return_value.order_by.return_value.first.return_value = ( - builtin_provider - ) + mock_db.session.scalar.return_value = builtin_provider encrypter = Mock() encrypter.decrypt.return_value = {"api_key": "secret"} cache = Mock() @@ -273,7 +272,7 @@ def test_get_tool_runtime_builtin_refreshes_expired_oauth_credentials( ) refreshed = SimpleNamespace(credentials={"token": "new"}, expires_at=123456) - mock_db.session.query.return_value.where.return_value.order_by.return_value.first.return_value = builtin_provider + mock_db.session.scalar.return_value = builtin_provider encrypter = Mock() encrypter.decrypt.return_value = {"token": "old"} encrypter.encrypt.return_value = {"token": "encrypted"} @@ -421,7 +420,7 @@ def test_get_agent_runtime_apply_runtime_parameters(): tool_runtime = SimpleNamespace(runtime=ToolRuntime(tenant_id="tenant-1", runtime_parameters={})) tool_runtime.get_merged_runtime_parameters = Mock(return_value=[parameter]) - with patch.object(ToolManager, "get_tool_runtime", return_value=tool_runtime): + with 
patch.object(ToolManager, "get_tool_runtime", return_value=tool_runtime) as mock_get_tool_runtime: with patch.object(ToolManager, "_convert_tool_parameters_type", return_value={"query": "hello"}): manager = Mock() manager.decrypt_tool_parameters.return_value = {"query": "decrypted"} @@ -437,12 +436,23 @@ def test_get_agent_runtime_apply_runtime_parameters(): tenant_id="tenant-1", app_id="app-1", agent_tool=agent_tool, + user_id="user-1", invoke_from=InvokeFrom.DEBUGGER, variable_pool=None, ) assert result is tool_runtime assert tool_runtime.runtime.runtime_parameters["query"] == "decrypted" + mock_get_tool_runtime.assert_called_once_with( + provider_type=ToolProviderType.API, + provider_id="api-1", + tool_name="search", + tenant_id="tenant-1", + user_id="user-1", + invoke_from=InvokeFrom.DEBUGGER, + tool_invoke_from=ToolInvokeFrom.AGENT, + credential_id=None, + ) def test_get_workflow_runtime_apply_runtime_parameters(): @@ -463,7 +473,7 @@ def test_get_workflow_runtime_apply_runtime_parameters(): ) tool_runtime2 = SimpleNamespace(runtime=ToolRuntime(tenant_id="tenant-1", runtime_parameters={})) tool_runtime2.get_merged_runtime_parameters = Mock(return_value=[parameter]) - with patch.object(ToolManager, "get_tool_runtime", return_value=tool_runtime2): + with patch.object(ToolManager, "get_tool_runtime", return_value=tool_runtime2) as mock_get_tool_runtime: with patch.object(ToolManager, "_convert_tool_parameters_type", return_value={"query": "workflow"}): manager = Mock() manager.decrypt_tool_parameters.return_value = {"query": "workflow-dec"} @@ -473,12 +483,23 @@ def test_get_workflow_runtime_apply_runtime_parameters(): app_id="app-1", node_id="node-1", workflow_tool=workflow_tool, + user_id="user-1", invoke_from=InvokeFrom.DEBUGGER, variable_pool=None, ) assert workflow_result is tool_runtime2 assert tool_runtime2.runtime.runtime_parameters["query"] == "workflow-dec" + mock_get_tool_runtime.assert_called_once_with( + provider_type=ToolProviderType.API, + 
provider_id="api-1", + tool_name="search", + tenant_id="tenant-1", + user_id="user-1", + invoke_from=InvokeFrom.DEBUGGER, + tool_invoke_from=ToolInvokeFrom.WORKFLOW, + credential_id=None, + ) def test_get_agent_runtime_raises_when_runtime_missing(): @@ -520,17 +541,28 @@ def test_get_tool_runtime_from_plugin_only_uses_form_parameters(): tool_entity = SimpleNamespace(runtime=ToolRuntime(tenant_id="tenant-1", runtime_parameters={})) tool_entity.get_merged_runtime_parameters = Mock(return_value=[form_param, llm_param]) - with patch.object(ToolManager, "get_tool_runtime", return_value=tool_entity): + with patch.object(ToolManager, "get_tool_runtime", return_value=tool_entity) as mock_get_tool_runtime: result = ToolManager.get_tool_runtime_from_plugin( tool_type=ToolProviderType.API, tenant_id="tenant-1", provider="api-1", tool_name="search", tool_parameters={"q": "hello", "llm": "ignore"}, + user_id="user-1", ) assert result is tool_entity assert tool_entity.runtime.runtime_parameters == {"q": "hello"} + mock_get_tool_runtime.assert_called_once_with( + provider_type=ToolProviderType.API, + provider_id="api-1", + tool_name="search", + tenant_id="tenant-1", + user_id="user-1", + invoke_from=InvokeFrom.SERVICE_API, + tool_invoke_from=ToolInvokeFrom.PLUGIN, + credential_id=None, + ) def test_hardcoded_provider_icon_success(): @@ -664,12 +696,10 @@ def test_get_api_provider_controller_returns_controller_and_credentials(): privacy_policy="privacy", custom_disclaimer="disclaimer", ) - db_query = Mock() - db_query.where.return_value.first.return_value = provider controller = Mock() with patch("core.tools.tool_manager.db") as mock_db: - mock_db.session.query.return_value = db_query + mock_db.session.scalar.return_value = provider with patch( "core.tools.tool_manager.ApiToolProviderController.from_db", return_value=controller ) as mock_from_db: @@ -696,12 +726,10 @@ def test_user_get_api_provider_masks_credentials_and_adds_labels(): privacy_policy="privacy", 
custom_disclaimer="disclaimer", ) - db_query = Mock() - db_query.where.return_value.first.return_value = provider controller = Mock() with patch("core.tools.tool_manager.db") as mock_db: - mock_db.session.query.return_value = db_query + mock_db.session.scalar.return_value = provider with patch("core.tools.tool_manager.ApiToolProviderController.from_db", return_value=controller): encrypter = Mock() encrypter.decrypt.return_value = {"api_key_value": "secret"} @@ -716,7 +744,7 @@ def test_user_get_api_provider_masks_credentials_and_adds_labels(): def test_get_api_provider_controller_not_found_raises(): with patch("core.tools.tool_manager.db") as mock_db: - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None with pytest.raises(ToolProviderNotFoundError, match="api provider missing not found"): ToolManager.get_api_provider_controller("tenant-1", "missing") @@ -775,14 +803,14 @@ def test_generate_tool_icon_urls_for_workflow_and_api(): workflow_provider = SimpleNamespace(icon='{"background": "#222", "content": "W"}') api_provider = SimpleNamespace(icon='{"background": "#333", "content": "A"}') with patch("core.tools.tool_manager.db") as mock_db: - mock_db.session.query.return_value.where.return_value.first.side_effect = [workflow_provider, api_provider] + mock_db.session.scalar.side_effect = [workflow_provider, api_provider] assert ToolManager.generate_workflow_tool_icon_url("tenant-1", "wf-1") == {"background": "#222", "content": "W"} assert ToolManager.generate_api_tool_icon_url("tenant-1", "api-1") == {"background": "#333", "content": "A"} def test_generate_tool_icon_urls_missing_workflow_and_api_use_default(): with patch("core.tools.tool_manager.db") as mock_db: - mock_db.session.query.return_value.where.return_value.first.return_value = None + mock_db.session.scalar.return_value = None assert ToolManager.generate_workflow_tool_icon_url("tenant-1", "missing")["background"] == "#252525" assert 
ToolManager.generate_api_tool_icon_url("tenant-1", "missing")["background"] == "#252525" diff --git a/api/tests/unit_tests/core/tools/utils/test_message_transformer.py b/api/tests/unit_tests/core/tools/utils/test_message_transformer.py index af3cdddd5f..6454a5bcd1 100644 --- a/api/tests/unit_tests/core/tools/utils/test_message_transformer.py +++ b/api/tests/unit_tests/core/tools/utils/test_message_transformer.py @@ -84,3 +84,24 @@ def test_transform_tool_invoke_messages_mimetype_key_present_but_none(): # meta is preserved (still contains mime_type: None) assert "mime_type" in (o.meta or {}) assert o.meta["mime_type"] is None + assert o.meta["tool_file_id"] == "fake-tool-file-id" + + +def test_transform_tool_invoke_messages_parses_existing_tool_file_link_meta(): + msg = ToolInvokeMessage( + type=ToolInvokeMessage.MessageType.IMAGE_LINK, + message=ToolInvokeMessage.TextMessage(text="/files/tools/existing-tool-file.png"), + meta={}, + ) + + out = list( + mt.ToolFileMessageTransformer.transform_tool_invoke_messages( + messages=_gen([msg]), + user_id="u1", + tenant_id="t1", + conversation_id="c1", + ) + ) + + assert len(out) == 1 + assert out[0].meta["tool_file_id"] == "existing-tool-file" diff --git a/api/tests/unit_tests/core/tools/utils/test_misc_utils_extra.py b/api/tests/unit_tests/core/tools/utils/test_misc_utils_extra.py index 4ce73272bf..a93624123e 100644 --- a/api/tests/unit_tests/core/tools/utils/test_misc_utils_extra.py +++ b/api/tests/unit_tests/core/tools/utils/test_misc_utils_extra.py @@ -263,7 +263,7 @@ def test_single_dataset_retriever_non_economy_run_sorts_context_and_resources(): ) db_session = Mock() db_session.scalar.side_effect = [dataset, lookup_doc_low, lookup_doc_high] - db_session.query.return_value.filter_by.return_value.first.return_value = dataset + db_session.get.return_value = dataset tool = SingleDatasetRetrieverTool( tenant_id="tenant-1", @@ -444,7 +444,7 @@ def test_multi_dataset_retriever_run_orders_segments_and_returns_resources(): ) 
db_session = Mock() db_session.scalars.return_value.all.return_value = [segment_for_node_2, segment_for_node_1] - db_session.query.return_value.filter_by.return_value.first.side_effect = [ + db_session.get.side_effect = [ SimpleNamespace(id="dataset-2", name="Dataset Two"), SimpleNamespace(id="dataset-1", name="Dataset One"), ] diff --git a/api/tests/unit_tests/core/tools/utils/test_model_invocation_utils.py b/api/tests/unit_tests/core/tools/utils/test_model_invocation_utils.py index 2acae889b2..52f262e1cf 100644 --- a/api/tests/unit_tests/core/tools/utils/test_model_invocation_utils.py +++ b/api/tests/unit_tests/core/tools/utils/test_model_invocation_utils.py @@ -13,10 +13,8 @@ from types import SimpleNamespace from unittest.mock import Mock, patch import pytest - -from core.tools.utils.model_invocation_utils import InvokeModelError, ModelInvocationUtils -from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey -from dify_graph.model_runtime.errors.invoke import ( +from graphon.model_runtime.entities.model_entities import ModelPropertyKey +from graphon.model_runtime.errors.invoke import ( InvokeAuthorizationError, InvokeBadRequestError, InvokeConnectionError, @@ -24,6 +22,8 @@ from dify_graph.model_runtime.errors.invoke import ( InvokeServerUnavailableError, ) +from core.tools.utils.model_invocation_utils import InvokeModelError, ModelInvocationUtils + def _mock_model_instance(*, schema: dict | None = None) -> SimpleNamespace: model_type_instance = Mock() @@ -60,20 +60,23 @@ def test_get_max_llm_context_tokens_branches(model_instance, expected, error_mat manager = Mock() manager.get_default_model_instance.return_value = model_instance - with patch("core.tools.utils.model_invocation_utils.ModelManager", return_value=manager): + with patch("core.tools.utils.model_invocation_utils.ModelManager.for_tenant", return_value=manager) as mock_factory: if error_match: with pytest.raises(InvokeModelError, match=error_match): - 
ModelInvocationUtils.get_max_llm_context_tokens("tenant") + ModelInvocationUtils.get_max_llm_context_tokens("tenant", user_id="user-1") else: - assert ModelInvocationUtils.get_max_llm_context_tokens("tenant") == expected + assert ModelInvocationUtils.get_max_llm_context_tokens("tenant", user_id="user-1") == expected + + mock_factory.assert_called_once_with(tenant_id="tenant", user_id="user-1") def test_calculate_tokens_handles_missing_model(): manager = Mock() manager.get_default_model_instance.return_value = None - with patch("core.tools.utils.model_invocation_utils.ModelManager", return_value=manager): + with patch("core.tools.utils.model_invocation_utils.ModelManager.for_tenant", return_value=manager) as mock_factory: with pytest.raises(InvokeModelError, match="Model not found"): ModelInvocationUtils.calculate_tokens("tenant", []) + mock_factory.assert_called_once_with(tenant_id="tenant", user_id=None) def test_invoke_success_and_error_mappings(): @@ -98,7 +101,7 @@ def test_invoke_success_and_error_mappings(): db_mock = SimpleNamespace(session=Mock()) - with patch("core.tools.utils.model_invocation_utils.ModelManager", return_value=manager): + with patch("core.tools.utils.model_invocation_utils.ModelManager.for_tenant", return_value=manager) as mock_factory: with patch("core.tools.utils.model_invocation_utils.ToolModelInvoke", _ToolModelInvoke): with patch("core.tools.utils.model_invocation_utils.db", db_mock): response = ModelInvocationUtils.invoke( @@ -107,11 +110,13 @@ def test_invoke_success_and_error_mappings(): tool_type="builtin", tool_name="tool-a", prompt_messages=[], + caller_user_id="caller-1", ) assert response.message.content == "ok" assert db_mock.session.add.call_count == 1 assert db_mock.session.commit.call_count == 2 + mock_factory.assert_called_once_with(tenant_id="tenant", user_id="caller-1") @pytest.mark.parametrize( @@ -145,7 +150,7 @@ def test_invoke_error_mappings(exc, expected): db_mock = SimpleNamespace(session=Mock()) - with 
patch("core.tools.utils.model_invocation_utils.ModelManager", return_value=manager): + with patch("core.tools.utils.model_invocation_utils.ModelManager.for_tenant", return_value=manager) as mock_factory: with patch("core.tools.utils.model_invocation_utils.ToolModelInvoke", _ToolModelInvoke): with patch("core.tools.utils.model_invocation_utils.db", db_mock): with pytest.raises(InvokeModelError, match=expected): @@ -156,3 +161,4 @@ def test_invoke_error_mappings(exc, expected): tool_name="tool-a", prompt_messages=[], ) + mock_factory.assert_called_once_with(tenant_id="tenant", user_id="u1") diff --git a/api/tests/unit_tests/core/tools/utils/test_workflow_configuration_sync.py b/api/tests/unit_tests/core/tools/utils/test_workflow_configuration_sync.py index dd79b79718..0e3a7e623a 100644 --- a/api/tests/unit_tests/core/tools/utils/test_workflow_configuration_sync.py +++ b/api/tests/unit_tests/core/tools/utils/test_workflow_configuration_sync.py @@ -1,9 +1,9 @@ import pytest +from graphon.variables.input_entities import VariableEntity, VariableEntityType from core.tools.entities.tool_entities import ToolParameter, WorkflowToolParameterConfiguration from core.tools.errors import WorkflowToolHumanInputNotSupportedError from core.tools.utils.workflow_configuration_sync import WorkflowToolConfigurationUtils -from dify_graph.variables.input_entities import VariableEntity, VariableEntityType def test_ensure_no_human_input_nodes_passes_for_non_human_input(): diff --git a/api/tests/unit_tests/core/tools/workflow_as_tool/test_provider.py b/api/tests/unit_tests/core/tools/workflow_as_tool/test_provider.py index dd140cbb27..2607861b59 100644 --- a/api/tests/unit_tests/core/tools/workflow_as_tool/test_provider.py +++ b/api/tests/unit_tests/core/tools/workflow_as_tool/test_provider.py @@ -4,6 +4,7 @@ from types import SimpleNamespace from unittest.mock import MagicMock, Mock, patch import pytest +from graphon.variables.input_entities import VariableEntity, VariableEntityType from 
core.tools.entities.common_entities import I18nObject from core.tools.entities.tool_entities import ( @@ -13,7 +14,6 @@ from core.tools.entities.tool_entities import ( ToolProviderType, ) from core.tools.workflow_as_tool.provider import WorkflowToolProviderController -from dify_graph.variables.input_entities import VariableEntity, VariableEntityType def _controller() -> WorkflowToolProviderController: diff --git a/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py b/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py index cc00f79698..c20edd7400 100644 --- a/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py +++ b/api/tests/unit_tests/core/tools/workflow_as_tool/test_tool.py @@ -11,6 +11,7 @@ from typing import Any from unittest.mock import MagicMock, Mock, patch import pytest +from graphon.file import FILE_MODEL_IDENTITY, FileTransferMethod, FileType from core.app.entities.app_invoke_entities import InvokeFrom from core.tools.__base.tool_runtime import ToolRuntime @@ -24,7 +25,6 @@ from core.tools.entities.tool_entities import ( ) from core.tools.errors import ToolInvokeError from core.tools.workflow_as_tool.tool import WorkflowTool -from dify_graph.file import FILE_MODEL_IDENTITY class StubScalars: @@ -439,6 +439,32 @@ def _setup_transform_args_tool(monkeypatch: pytest.MonkeyPatch) -> WorkflowTool: def test_transform_args_valid_files(monkeypatch: pytest.MonkeyPatch): """Transform args into parameters and files payloads.""" tool = _setup_transform_args_tool(monkeypatch) + build_file_from_stored_mapping = MagicMock( + side_effect=[ + SimpleNamespace( + transfer_method=FileTransferMethod.TOOL_FILE, + type=FileType.IMAGE, + reference="tool-1", + generate_url=lambda: None, + ), + SimpleNamespace( + transfer_method=FileTransferMethod.LOCAL_FILE, + type=FileType.DOCUMENT, + reference="upload-1", + generate_url=lambda: None, + ), + SimpleNamespace( + transfer_method=FileTransferMethod.REMOTE_URL, + type=FileType.DOCUMENT, + reference=None, 
+ generate_url=lambda: "https://example.com/a.pdf", + ), + ] + ) + monkeypatch.setattr( + "core.tools.workflow_as_tool.tool.build_file_from_stored_mapping", + build_file_from_stored_mapping, + ) params, files = tool._transform_args( { @@ -470,6 +496,8 @@ def test_transform_args_valid_files(monkeypatch: pytest.MonkeyPatch): assert any(file_item.get("tool_file_id") == "tool-1" for file_item in files) assert any(file_item.get("upload_file_id") == "upload-1" for file_item in files) assert any(file_item.get("url") == "https://example.com/a.pdf" for file_item in files) + assert build_file_from_stored_mapping.call_count == 3 + assert all(call.kwargs["tenant_id"] == "test_tool" for call in build_file_from_stored_mapping.call_args_list) def test_transform_args_invalid_files(monkeypatch: pytest.MonkeyPatch): diff --git a/api/tests/unit_tests/core/trigger/debug/test_debug_event_selectors.py b/api/tests/unit_tests/core/trigger/debug/test_debug_event_selectors.py index bcb1d745e3..78622b78b6 100644 --- a/api/tests/unit_tests/core/trigger/debug/test_debug_event_selectors.py +++ b/api/tests/unit_tests/core/trigger/debug/test_debug_event_selectors.py @@ -11,6 +11,7 @@ from datetime import datetime from unittest.mock import MagicMock, patch import pytest +from graphon.enums import BuiltinNodeTypes, NodeType from core.plugin.entities.request import TriggerInvokeEventResponse from core.trigger.constants import ( @@ -26,7 +27,6 @@ from core.trigger.debug.event_selectors import ( select_trigger_debug_events, ) from core.trigger.debug.events import PluginTriggerDebugEvent, WebhookDebugEvent -from dify_graph.enums import BuiltinNodeTypes, NodeType from tests.unit_tests.core.trigger.conftest import VALID_PROVIDER_ID diff --git a/api/tests/unit_tests/core/variables/test_segment.py b/api/tests/unit_tests/core/variables/test_segment.py index 91259c9a45..7406b88270 100644 --- a/api/tests/unit_tests/core/variables/test_segment.py +++ b/api/tests/unit_tests/core/variables/test_segment.py @@ 
-2,14 +2,10 @@ import dataclasses import orjson import pytest -from pydantic import BaseModel - -from core.helper import encrypter -from dify_graph.file import File, FileTransferMethod, FileType -from dify_graph.runtime import VariablePool -from dify_graph.system_variable import SystemVariable -from dify_graph.variables.segment_group import SegmentGroup -from dify_graph.variables.segments import ( +from graphon.file import File, FileTransferMethod, FileType +from graphon.runtime import VariablePool +from graphon.variables.segment_group import SegmentGroup +from graphon.variables.segments import ( ArrayAnySegment, ArrayFileSegment, ArrayNumberSegment, @@ -25,13 +21,13 @@ from dify_graph.variables.segments import ( StringSegment, get_segment_discriminator, ) -from dify_graph.variables.types import SegmentType -from dify_graph.variables.utils import ( +from graphon.variables.types import SegmentType +from graphon.variables.utils import ( dumps_with_segments, segment_orjson_default, to_selector, ) -from dify_graph.variables.variables import ( +from graphon.variables.variables import ( ArrayAnyVariable, ArrayFileVariable, ArrayNumberVariable, @@ -46,16 +42,35 @@ from dify_graph.variables.variables import ( StringVariable, Variable, ) +from pydantic import BaseModel + +from core.helper import encrypter +from core.workflow.system_variables import build_bootstrap_variables, build_system_variables +from core.workflow.variable_pool_initializer import add_variables_to_pool + + +def _build_variable_pool( + *, + system_variables: list[Variable] | None = None, + environment_variables: list[Variable] | None = None, +) -> VariablePool: + variable_pool = VariablePool() + add_variables_to_pool( + variable_pool, + build_bootstrap_variables( + system_variables=system_variables or [], + environment_variables=environment_variables or [], + ), + ) + return variable_pool def test_segment_group_to_text(): - variable_pool = VariablePool( - 
system_variables=SystemVariable(user_id="fake-user-id"), - user_inputs={}, + variable_pool = _build_variable_pool( + system_variables=build_system_variables(user_id="fake-user-id"), environment_variables=[ SecretVariable(name="secret_key", value="fake-secret-key"), ], - conversation_variables=[], ) variable_pool.add(("node_id", "custom_query"), "fake-user-query") template = ( @@ -71,11 +86,8 @@ def test_segment_group_to_text(): def test_convert_constant_to_segment_group(): - variable_pool = VariablePool( - system_variables=SystemVariable(user_id="1", app_id="1", workflow_id="1"), - user_inputs={}, - environment_variables=[], - conversation_variables=[], + variable_pool = _build_variable_pool( + system_variables=build_system_variables(user_id="1", app_id="1", workflow_id="1"), ) template = "Hello, world!" segments_group = variable_pool.convert_template(template) @@ -84,12 +96,7 @@ def test_convert_constant_to_segment_group(): def test_convert_variable_to_segment_group(): - variable_pool = VariablePool( - system_variables=SystemVariable(user_id="fake-user-id"), - user_inputs={}, - environment_variables=[], - conversation_variables=[], - ) + variable_pool = _build_variable_pool(system_variables=build_system_variables(user_id="fake-user-id")) template = "{{#sys.user_id#}}" segments_group = variable_pool.convert_template(template) assert segments_group.text == "fake-user-id" @@ -116,7 +123,6 @@ def create_test_file( ) -> File: """Factory function to create File objects for testing""" return File( - tenant_id="test-tenant", type=file_type, transfer_method=transfer_method, filename=filename, @@ -190,7 +196,6 @@ class TestSegmentDumpAndLoad: loaded_file = loaded_segment.value assert isinstance(orig_file, File) assert isinstance(loaded_file, File) - assert loaded_file.tenant_id == orig_file.tenant_id assert loaded_file.type == orig_file.type assert loaded_file.filename == orig_file.filename else: @@ -234,7 +239,6 @@ class TestSegmentDumpAndLoad: loaded_file = 
loaded_variable.value assert isinstance(orig_file, File) assert isinstance(loaded_file, File) - assert loaded_file.tenant_id == orig_file.tenant_id assert loaded_file.type == orig_file.type assert loaded_file.filename == orig_file.filename else: diff --git a/api/tests/unit_tests/core/variables/test_segment_type.py b/api/tests/unit_tests/core/variables/test_segment_type.py index bb234d9bbd..37ecd2890b 100644 --- a/api/tests/unit_tests/core/variables/test_segment_type.py +++ b/api/tests/unit_tests/core/variables/test_segment_type.py @@ -1,8 +1,7 @@ import pytest - -from dify_graph.variables.segment_group import SegmentGroup -from dify_graph.variables.segments import StringSegment -from dify_graph.variables.types import ArrayValidation, SegmentType +from graphon.variables.segment_group import SegmentGroup +from graphon.variables.segments import StringSegment +from graphon.variables.types import ArrayValidation, SegmentType class TestSegmentTypeIsArrayType: diff --git a/api/tests/unit_tests/core/variables/test_segment_type_validation.py b/api/tests/unit_tests/core/variables/test_segment_type_validation.py index 41ce483447..09254e17a3 100644 --- a/api/tests/unit_tests/core/variables/test_segment_type_validation.py +++ b/api/tests/unit_tests/core/variables/test_segment_type_validation.py @@ -9,11 +9,9 @@ from dataclasses import dataclass from typing import Any import pytest - -from dify_graph.file.enums import FileTransferMethod, FileType -from dify_graph.file.models import File -from dify_graph.variables.segment_group import SegmentGroup -from dify_graph.variables.segments import ( +from graphon.file import File, FileTransferMethod, FileType +from graphon.variables.segment_group import SegmentGroup +from graphon.variables.segments import ( ArrayFileSegment, BooleanSegment, FileSegment, @@ -22,7 +20,7 @@ from dify_graph.variables.segments import ( ObjectSegment, StringSegment, ) -from dify_graph.variables.types import ArrayValidation, SegmentType +from 
graphon.variables.types import ArrayValidation, SegmentType def create_test_file( diff --git a/api/tests/unit_tests/core/variables/test_variables.py b/api/tests/unit_tests/core/variables/test_variables.py index dd0fe2e65a..75b01bf42e 100644 --- a/api/tests/unit_tests/core/variables/test_variables.py +++ b/api/tests/unit_tests/core/variables/test_variables.py @@ -1,7 +1,5 @@ import pytest -from pydantic import ValidationError - -from dify_graph.variables import ( +from graphon.variables import ( ArrayFileVariable, ArrayVariable, FloatVariable, @@ -11,7 +9,8 @@ from dify_graph.variables import ( SegmentType, StringVariable, ) -from dify_graph.variables.variables import VariableBase +from graphon.variables.variables import VariableBase +from pydantic import ValidationError def test_frozen_variables(): diff --git a/api/tests/unit_tests/core/workflow/context/test_execution_context.py b/api/tests/unit_tests/core/workflow/context/test_execution_context.py index d09b8397c3..3ce4bb753b 100644 --- a/api/tests/unit_tests/core/workflow/context/test_execution_context.py +++ b/api/tests/unit_tests/core/workflow/context/test_execution_context.py @@ -9,7 +9,7 @@ from unittest.mock import MagicMock import pytest from pydantic import BaseModel -from dify_graph.context.execution_context import ( +from context.execution_context import ( AppContext, ExecutionContext, ExecutionContextBuilder, @@ -286,7 +286,7 @@ class TestCaptureCurrentContext: def test_capture_current_context_returns_context(self): """Test that capture_current_context returns a valid context.""" - from dify_graph.context.execution_context import capture_current_context + from context.execution_context import capture_current_context result = capture_current_context() @@ -303,7 +303,7 @@ class TestCaptureCurrentContext: test_var = contextvars.ContextVar("capture_test_var") test_var.set("test_value_123") - from dify_graph.context.execution_context import capture_current_context + from context.execution_context import 
capture_current_context result = capture_current_context() @@ -313,12 +313,12 @@ class TestCaptureCurrentContext: class TestTenantScopedContextRegistry: def setup_method(self): - from dify_graph.context import reset_context_provider + from context import reset_context_provider reset_context_provider() def teardown_method(self): - from dify_graph.context import reset_context_provider + from context import reset_context_provider reset_context_provider() @@ -333,7 +333,7 @@ class TestTenantScopedContextRegistry: assert read_context("workflow.sandbox", tenant_id="t2").base_url == "http://t2" def test_missing_provider_raises_keyerror(self): - from dify_graph.context import ContextProviderNotFoundError + from context import ContextProviderNotFoundError with pytest.raises(ContextProviderNotFoundError): read_context("missing", tenant_id="unknown") diff --git a/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py b/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py deleted file mode 100644 index 22792eb5b3..0000000000 --- a/api/tests/unit_tests/core/workflow/entities/test_graph_runtime_state.py +++ /dev/null @@ -1,296 +0,0 @@ -import json -from time import time -from unittest.mock import MagicMock, patch - -import pytest - -from dify_graph.constants import CONVERSATION_VARIABLE_NODE_ID -from dify_graph.model_runtime.entities.llm_entities import LLMUsage -from dify_graph.runtime import GraphRuntimeState, ReadOnlyGraphRuntimeStateWrapper, VariablePool -from dify_graph.variables.variables import StringVariable - - -class StubCoordinator: - def __init__(self) -> None: - self.state = "initial" - - def dumps(self) -> str: - return json.dumps({"state": self.state}) - - def loads(self, data: str) -> None: - payload = json.loads(data) - self.state = payload["state"] - - -class TestGraphRuntimeState: - def test_property_getters_and_setters(self): - # FIXME(-LAN-): Mock VariablePool if needed - variable_pool = VariablePool() - start_time = 
time() - - state = GraphRuntimeState(variable_pool=variable_pool, start_at=start_time) - - # Test variable_pool property (read-only) - assert state.variable_pool == variable_pool - - # Test start_at property - assert state.start_at == start_time - new_time = time() + 100 - state.start_at = new_time - assert state.start_at == new_time - - # Test total_tokens property - assert state.total_tokens == 0 - state.total_tokens = 100 - assert state.total_tokens == 100 - - # Test node_run_steps property - assert state.node_run_steps == 0 - state.node_run_steps = 5 - assert state.node_run_steps == 5 - - def test_outputs_immutability(self): - variable_pool = VariablePool() - state = GraphRuntimeState(variable_pool=variable_pool, start_at=time()) - - # Test that getting outputs returns a copy - outputs1 = state.outputs - outputs2 = state.outputs - assert outputs1 == outputs2 - assert outputs1 is not outputs2 # Different objects - - # Test that modifying retrieved outputs doesn't affect internal state - outputs = state.outputs - outputs["test"] = "value" - assert "test" not in state.outputs - - # Test set_output method - state.set_output("key1", "value1") - assert state.get_output("key1") == "value1" - - # Test update_outputs method - state.update_outputs({"key2": "value2", "key3": "value3"}) - assert state.get_output("key2") == "value2" - assert state.get_output("key3") == "value3" - - def test_llm_usage_immutability(self): - variable_pool = VariablePool() - state = GraphRuntimeState(variable_pool=variable_pool, start_at=time()) - - # Test that getting llm_usage returns a copy - usage1 = state.llm_usage - usage2 = state.llm_usage - assert usage1 is not usage2 # Different objects - - def test_type_validation(self): - variable_pool = VariablePool() - state = GraphRuntimeState(variable_pool=variable_pool, start_at=time()) - - # Test total_tokens validation - with pytest.raises(ValueError): - state.total_tokens = -1 - - # Test node_run_steps validation - with 
pytest.raises(ValueError): - state.node_run_steps = -1 - - def test_helper_methods(self): - variable_pool = VariablePool() - state = GraphRuntimeState(variable_pool=variable_pool, start_at=time()) - - # Test increment_node_run_steps - initial_steps = state.node_run_steps - state.increment_node_run_steps() - assert state.node_run_steps == initial_steps + 1 - - # Test add_tokens - initial_tokens = state.total_tokens - state.add_tokens(50) - assert state.total_tokens == initial_tokens + 50 - - # Test add_tokens validation - with pytest.raises(ValueError): - state.add_tokens(-1) - - def test_ready_queue_default_instantiation(self): - state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time()) - - queue = state.ready_queue - - from dify_graph.graph_engine.ready_queue import InMemoryReadyQueue - - assert isinstance(queue, InMemoryReadyQueue) - - def test_graph_execution_lazy_instantiation(self): - state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time()) - - execution = state.graph_execution - - from dify_graph.graph_engine.domain.graph_execution import GraphExecution - - assert isinstance(execution, GraphExecution) - assert execution.workflow_id == "" - assert state.graph_execution is execution - - def test_response_coordinator_configuration(self): - variable_pool = VariablePool() - state = GraphRuntimeState(variable_pool=variable_pool, start_at=time()) - - with pytest.raises(ValueError): - _ = state.response_coordinator - - mock_graph = MagicMock() - with patch( - "dify_graph.graph_engine.response_coordinator.ResponseStreamCoordinator", autospec=True - ) as coordinator_cls: - coordinator_instance = coordinator_cls.return_value - state.configure(graph=mock_graph) - - assert state.response_coordinator is coordinator_instance - coordinator_cls.assert_called_once_with(variable_pool=variable_pool, graph=mock_graph) - - # Configure again with same graph should be idempotent - state.configure(graph=mock_graph) - - other_graph = MagicMock() - with 
pytest.raises(ValueError): - state.attach_graph(other_graph) - - def test_read_only_wrapper_exposes_additional_state(self): - state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time()) - state.configure() - - wrapper = ReadOnlyGraphRuntimeStateWrapper(state) - - assert wrapper.ready_queue_size == 0 - assert wrapper.exceptions_count == 0 - - def test_read_only_wrapper_serializes_runtime_state(self): - state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time()) - state.total_tokens = 5 - state.set_output("result", {"success": True}) - state.ready_queue.put("node-1") - - wrapper = ReadOnlyGraphRuntimeStateWrapper(state) - - wrapper_snapshot = json.loads(wrapper.dumps()) - state_snapshot = json.loads(state.dumps()) - - assert wrapper_snapshot == state_snapshot - - def test_dumps_and_loads_roundtrip_with_response_coordinator(self): - variable_pool = VariablePool() - variable_pool.add(("node1", "value"), "payload") - - state = GraphRuntimeState(variable_pool=variable_pool, start_at=time()) - state.total_tokens = 10 - state.node_run_steps = 3 - state.set_output("final", {"result": True}) - usage = LLMUsage.from_metadata( - { - "prompt_tokens": 2, - "completion_tokens": 3, - "total_tokens": 5, - "total_price": "1.23", - "currency": "USD", - "latency": 0.5, - } - ) - state.llm_usage = usage - state.ready_queue.put("node-A") - - graph_execution = state.graph_execution - graph_execution.workflow_id = "wf-123" - graph_execution.exceptions_count = 4 - graph_execution.started = True - - mock_graph = MagicMock() - stub = StubCoordinator() - with patch.object(GraphRuntimeState, "_build_response_coordinator", return_value=stub, autospec=True): - state.attach_graph(mock_graph) - - stub.state = "configured" - - snapshot = state.dumps() - - restored = GraphRuntimeState.from_snapshot(snapshot) - - assert restored.total_tokens == 10 - assert restored.node_run_steps == 3 - assert restored.get_output("final") == {"result": True} - assert 
restored.llm_usage.total_tokens == usage.total_tokens - assert restored.ready_queue.qsize() == 1 - assert restored.ready_queue.get(timeout=0.01) == "node-A" - - restored_segment = restored.variable_pool.get(("node1", "value")) - assert restored_segment is not None - assert restored_segment.value == "payload" - - restored_execution = restored.graph_execution - assert restored_execution.workflow_id == "wf-123" - assert restored_execution.exceptions_count == 4 - assert restored_execution.started is True - - new_stub = StubCoordinator() - with patch.object(GraphRuntimeState, "_build_response_coordinator", return_value=new_stub, autospec=True): - restored.attach_graph(mock_graph) - - assert new_stub.state == "configured" - - def test_loads_rehydrates_existing_instance(self): - variable_pool = VariablePool() - variable_pool.add(("node", "key"), "value") - - state = GraphRuntimeState(variable_pool=variable_pool, start_at=time()) - state.total_tokens = 7 - state.node_run_steps = 2 - state.set_output("foo", "bar") - state.ready_queue.put("node-1") - - execution = state.graph_execution - execution.workflow_id = "wf-456" - execution.started = True - - mock_graph = MagicMock() - original_stub = StubCoordinator() - with patch.object(GraphRuntimeState, "_build_response_coordinator", return_value=original_stub, autospec=True): - state.attach_graph(mock_graph) - - original_stub.state = "configured" - snapshot = state.dumps() - - new_stub = StubCoordinator() - with patch.object(GraphRuntimeState, "_build_response_coordinator", return_value=new_stub, autospec=True): - restored = GraphRuntimeState(variable_pool=VariablePool(), start_at=0.0) - restored.attach_graph(mock_graph) - restored.loads(snapshot) - - assert restored.total_tokens == 7 - assert restored.node_run_steps == 2 - assert restored.get_output("foo") == "bar" - assert restored.ready_queue.qsize() == 1 - assert restored.ready_queue.get(timeout=0.01) == "node-1" - - restored_segment = restored.variable_pool.get(("node", 
"key")) - assert restored_segment is not None - assert restored_segment.value == "value" - - restored_execution = restored.graph_execution - assert restored_execution.workflow_id == "wf-456" - assert restored_execution.started is True - - assert new_stub.state == "configured" - - def test_snapshot_restore_preserves_updated_conversation_variable(self): - variable_pool = VariablePool( - conversation_variables=[StringVariable(name="session_name", value="before")], - ) - variable_pool.add((CONVERSATION_VARIABLE_NODE_ID, "session_name"), "after") - - state = GraphRuntimeState(variable_pool=variable_pool, start_at=time()) - snapshot = state.dumps() - restored = GraphRuntimeState.from_snapshot(snapshot) - - restored_value = restored.variable_pool.get((CONVERSATION_VARIABLE_NODE_ID, "session_name")) - assert restored_value is not None - assert restored_value.value == "after" diff --git a/api/tests/unit_tests/core/workflow/entities/test_pause_reason.py b/api/tests/unit_tests/core/workflow/entities/test_pause_reason.py deleted file mode 100644 index 158f7018b5..0000000000 --- a/api/tests/unit_tests/core/workflow/entities/test_pause_reason.py +++ /dev/null @@ -1,88 +0,0 @@ -""" -Tests for PauseReason discriminated union serialization/deserialization. 
-""" - -import pytest -from pydantic import BaseModel, ValidationError - -from dify_graph.entities.pause_reason import ( - HumanInputRequired, - PauseReason, - SchedulingPause, -) - - -class _Holder(BaseModel): - """Helper model that embeds PauseReason for union tests.""" - - reason: PauseReason - - -class TestPauseReasonDiscriminator: - """Test suite for PauseReason union discriminator.""" - - @pytest.mark.parametrize( - ("dict_value", "expected"), - [ - pytest.param( - { - "reason": { - "TYPE": "human_input_required", - "form_id": "form_id", - "form_content": "form_content", - "node_id": "node_id", - "node_title": "node_title", - }, - }, - HumanInputRequired( - form_id="form_id", - form_content="form_content", - node_id="node_id", - node_title="node_title", - ), - id="HumanInputRequired", - ), - pytest.param( - { - "reason": { - "TYPE": "scheduled_pause", - "message": "Hold on", - } - }, - SchedulingPause(message="Hold on"), - id="SchedulingPause", - ), - ], - ) - def test_model_validate(self, dict_value, expected): - """Ensure scheduled pause payloads with lowercase TYPE deserialize.""" - holder = _Holder.model_validate(dict_value) - - assert type(holder.reason) == type(expected) - assert holder.reason == expected - - @pytest.mark.parametrize( - "reason", - [ - HumanInputRequired( - form_id="form_id", - form_content="form_content", - node_id="node_id", - node_title="node_title", - ), - SchedulingPause(message="Hold on"), - ], - ids=lambda x: type(x).__name__, - ) - def test_model_construct(self, reason): - holder = _Holder(reason=reason) - assert holder.reason == reason - - def test_model_construct_with_invalid_type(self): - with pytest.raises(ValidationError): - holder = _Holder(reason=object()) # type: ignore - - def test_unknown_type_fails_validation(self): - """Unknown TYPE values should raise a validation error.""" - with pytest.raises(ValidationError): - _Holder.model_validate({"reason": {"TYPE": "UNKNOWN"}}) diff --git 
a/api/tests/unit_tests/core/workflow/entities/test_template.py b/api/tests/unit_tests/core/workflow/entities/test_template.py deleted file mode 100644 index 2d4c7f7b77..0000000000 --- a/api/tests/unit_tests/core/workflow/entities/test_template.py +++ /dev/null @@ -1,87 +0,0 @@ -"""Tests for template module.""" - -from dify_graph.nodes.base.template import Template, TextSegment, VariableSegment - - -class TestTemplate: - """Test Template class functionality.""" - - def test_from_answer_template_simple(self): - """Test parsing a simple answer template.""" - template_str = "Hello, {{#node1.name#}}!" - template = Template.from_answer_template(template_str) - - assert len(template.segments) == 3 - assert isinstance(template.segments[0], TextSegment) - assert template.segments[0].text == "Hello, " - assert isinstance(template.segments[1], VariableSegment) - assert template.segments[1].selector == ["node1", "name"] - assert isinstance(template.segments[2], TextSegment) - assert template.segments[2].text == "!" - - def test_from_answer_template_multiple_vars(self): - """Test parsing an answer template with multiple variables.""" - template_str = "Hello {{#node1.name#}}, your age is {{#node2.age#}}." - template = Template.from_answer_template(template_str) - - assert len(template.segments) == 5 - assert isinstance(template.segments[0], TextSegment) - assert template.segments[0].text == "Hello " - assert isinstance(template.segments[1], VariableSegment) - assert template.segments[1].selector == ["node1", "name"] - assert isinstance(template.segments[2], TextSegment) - assert template.segments[2].text == ", your age is " - assert isinstance(template.segments[3], VariableSegment) - assert template.segments[3].selector == ["node2", "age"] - assert isinstance(template.segments[4], TextSegment) - assert template.segments[4].text == "." - - def test_from_answer_template_no_vars(self): - """Test parsing an answer template with no variables.""" - template_str = "Hello, world!" 
- template = Template.from_answer_template(template_str) - - assert len(template.segments) == 1 - assert isinstance(template.segments[0], TextSegment) - assert template.segments[0].text == "Hello, world!" - - def test_from_end_outputs_single(self): - """Test creating template from End node outputs with single variable.""" - outputs_config = [{"variable": "text", "value_selector": ["node1", "text"]}] - template = Template.from_end_outputs(outputs_config) - - assert len(template.segments) == 1 - assert isinstance(template.segments[0], VariableSegment) - assert template.segments[0].selector == ["node1", "text"] - - def test_from_end_outputs_multiple(self): - """Test creating template from End node outputs with multiple variables.""" - outputs_config = [ - {"variable": "text", "value_selector": ["node1", "text"]}, - {"variable": "result", "value_selector": ["node2", "result"]}, - ] - template = Template.from_end_outputs(outputs_config) - - assert len(template.segments) == 3 - assert isinstance(template.segments[0], VariableSegment) - assert template.segments[0].selector == ["node1", "text"] - assert template.segments[0].variable_name == "text" - assert isinstance(template.segments[1], TextSegment) - assert template.segments[1].text == "\n" - assert isinstance(template.segments[2], VariableSegment) - assert template.segments[2].selector == ["node2", "result"] - assert template.segments[2].variable_name == "result" - - def test_from_end_outputs_empty(self): - """Test creating template from empty End node outputs.""" - outputs_config = [] - template = Template.from_end_outputs(outputs_config) - - assert len(template.segments) == 0 - - def test_template_str_representation(self): - """Test string representation of template.""" - template_str = "Hello, {{#node1.name#}}!" 
- template = Template.from_answer_template(template_str) - - assert str(template) == template_str diff --git a/api/tests/unit_tests/core/workflow/entities/test_variable_pool.py b/api/tests/unit_tests/core/workflow/entities/test_variable_pool.py deleted file mode 100644 index 6100ebede5..0000000000 --- a/api/tests/unit_tests/core/workflow/entities/test_variable_pool.py +++ /dev/null @@ -1,136 +0,0 @@ -from dify_graph.runtime import VariablePool -from dify_graph.variables.segments import ( - BooleanSegment, - IntegerSegment, - NoneSegment, - StringSegment, -) - - -class TestVariablePoolGetAndNestedAttribute: - # - # _get_nested_attribute tests - # - def test__get_nested_attribute_existing_key(self): - pool = VariablePool.empty() - obj = {"a": 123} - segment = pool._get_nested_attribute(obj, "a") - assert segment is not None - assert segment.value == 123 - - def test__get_nested_attribute_missing_key(self): - pool = VariablePool.empty() - obj = {"a": 123} - segment = pool._get_nested_attribute(obj, "b") - assert segment is None - - def test__get_nested_attribute_non_dict(self): - pool = VariablePool.empty() - obj = ["not", "a", "dict"] - segment = pool._get_nested_attribute(obj, "a") - assert segment is None - - def test__get_nested_attribute_with_none_value(self): - pool = VariablePool.empty() - obj = {"a": None} - segment = pool._get_nested_attribute(obj, "a") - assert segment is not None - assert isinstance(segment, NoneSegment) - - def test__get_nested_attribute_with_empty_string(self): - pool = VariablePool.empty() - obj = {"a": ""} - segment = pool._get_nested_attribute(obj, "a") - assert segment is not None - assert isinstance(segment, StringSegment) - assert segment.value == "" - - # - # get tests - # - def test_get_simple_variable(self): - pool = VariablePool.empty() - pool.add(("node1", "var1"), "value1") - segment = pool.get(("node1", "var1")) - assert segment is not None - assert segment.value == "value1" - - def test_get_missing_variable(self): - pool = 
VariablePool.empty() - result = pool.get(("node1", "unknown")) - assert result is None - - def test_get_with_too_short_selector(self): - pool = VariablePool.empty() - result = pool.get(("only_node",)) - assert result is None - - def test_get_nested_object_attribute(self): - pool = VariablePool.empty() - obj_value = {"inner": "hello"} - pool.add(("node1", "obj"), obj_value) - - # simulate selector with nested attr - segment = pool.get(("node1", "obj", "inner")) - assert segment is not None - assert segment.value == "hello" - - def test_get_nested_object_missing_attribute(self): - pool = VariablePool.empty() - obj_value = {"inner": "hello"} - pool.add(("node1", "obj"), obj_value) - - result = pool.get(("node1", "obj", "not_exist")) - assert result is None - - def test_get_nested_object_attribute_with_falsy_values(self): - pool = VariablePool.empty() - obj_value = { - "inner_none": None, - "inner_empty": "", - "inner_zero": 0, - "inner_false": False, - } - pool.add(("node1", "obj"), obj_value) - - segment_none = pool.get(("node1", "obj", "inner_none")) - assert segment_none is not None - assert isinstance(segment_none, NoneSegment) - - segment_empty = pool.get(("node1", "obj", "inner_empty")) - assert segment_empty is not None - assert isinstance(segment_empty, StringSegment) - assert segment_empty.value == "" - - segment_zero = pool.get(("node1", "obj", "inner_zero")) - assert segment_zero is not None - assert isinstance(segment_zero, IntegerSegment) - assert segment_zero.value == 0 - - segment_false = pool.get(("node1", "obj", "inner_false")) - assert segment_false is not None - assert isinstance(segment_false, BooleanSegment) - assert segment_false.value is False - - -class TestVariablePoolGetNotModifyVariableDictionary: - _NODE_ID = "start" - _VAR_NAME = "name" - - def test_convert_to_template_should_not_introduce_extra_keys(self): - pool = VariablePool.empty() - pool.add([self._NODE_ID, self._VAR_NAME], 0) - pool.convert_template("The start.name is 
{{#start.name#}}") - assert "The start" not in pool.variable_dictionary - - def test_get_should_not_modify_variable_dictionary(self): - pool = VariablePool.empty() - pool.get([self._NODE_ID, self._VAR_NAME]) - assert len(pool.variable_dictionary) == 1 # only contains `sys` node id - assert "start" not in pool.variable_dictionary - - pool = VariablePool.empty() - pool.add([self._NODE_ID, self._VAR_NAME], "Joe") - pool.get([self._NODE_ID, "count"]) - start_subdict = pool.variable_dictionary[self._NODE_ID] - assert "count" not in start_subdict diff --git a/api/tests/unit_tests/core/workflow/entities/test_workflow_node_execution.py b/api/tests/unit_tests/core/workflow/entities/test_workflow_node_execution.py deleted file mode 100644 index 216e64db8d..0000000000 --- a/api/tests/unit_tests/core/workflow/entities/test_workflow_node_execution.py +++ /dev/null @@ -1,225 +0,0 @@ -""" -Unit tests for WorkflowNodeExecution domain model, focusing on process_data truncation functionality. -""" - -from dataclasses import dataclass -from datetime import datetime -from typing import Any - -import pytest - -from dify_graph.entities.workflow_node_execution import WorkflowNodeExecution -from dify_graph.enums import BuiltinNodeTypes - - -class TestWorkflowNodeExecutionProcessDataTruncation: - """Test process_data truncation functionality in WorkflowNodeExecution domain model.""" - - def create_workflow_node_execution( - self, - process_data: dict[str, Any] | None = None, - ) -> WorkflowNodeExecution: - """Create a WorkflowNodeExecution instance for testing.""" - return WorkflowNodeExecution( - id="test-execution-id", - workflow_id="test-workflow-id", - index=1, - node_id="test-node-id", - node_type=BuiltinNodeTypes.LLM, - title="Test Node", - process_data=process_data, - created_at=datetime.now(), - ) - - def test_initial_process_data_truncated_state(self): - """Test that process_data_truncated returns False initially.""" - execution = self.create_workflow_node_execution() - - assert 
execution.process_data_truncated is False - assert execution.get_truncated_process_data() is None - - def test_set_and_get_truncated_process_data(self): - """Test setting and getting truncated process_data.""" - execution = self.create_workflow_node_execution() - test_truncated_data = {"truncated": True, "key": "value"} - - execution.set_truncated_process_data(test_truncated_data) - - assert execution.process_data_truncated is True - assert execution.get_truncated_process_data() == test_truncated_data - - def test_set_truncated_process_data_to_none(self): - """Test setting truncated process_data to None.""" - execution = self.create_workflow_node_execution() - - # First set some data - execution.set_truncated_process_data({"key": "value"}) - assert execution.process_data_truncated is True - - # Then set to None - execution.set_truncated_process_data(None) - assert execution.process_data_truncated is False - assert execution.get_truncated_process_data() is None - - def test_get_response_process_data_with_no_truncation(self): - """Test get_response_process_data when no truncation is set.""" - original_data = {"original": True, "data": "value"} - execution = self.create_workflow_node_execution(process_data=original_data) - - response_data = execution.get_response_process_data() - - assert response_data == original_data - assert execution.process_data_truncated is False - - def test_get_response_process_data_with_truncation(self): - """Test get_response_process_data when truncation is set.""" - original_data = {"original": True, "large_data": "x" * 10000} - truncated_data = {"original": True, "large_data": "[TRUNCATED]"} - - execution = self.create_workflow_node_execution(process_data=original_data) - execution.set_truncated_process_data(truncated_data) - - response_data = execution.get_response_process_data() - - # Should return truncated data, not original - assert response_data == truncated_data - assert response_data != original_data - assert 
execution.process_data_truncated is True - - def test_get_response_process_data_with_none_process_data(self): - """Test get_response_process_data when process_data is None.""" - execution = self.create_workflow_node_execution(process_data=None) - - response_data = execution.get_response_process_data() - - assert response_data is None - assert execution.process_data_truncated is False - - def test_consistency_with_inputs_outputs_pattern(self): - """Test that process_data truncation follows the same pattern as inputs/outputs.""" - execution = self.create_workflow_node_execution() - - # Test that all truncation methods exist and behave consistently - test_data = {"test": "data"} - - # Test inputs truncation - execution.set_truncated_inputs(test_data) - assert execution.inputs_truncated is True - assert execution.get_truncated_inputs() == test_data - - # Test outputs truncation - execution.set_truncated_outputs(test_data) - assert execution.outputs_truncated is True - assert execution.get_truncated_outputs() == test_data - - # Test process_data truncation - execution.set_truncated_process_data(test_data) - assert execution.process_data_truncated is True - assert execution.get_truncated_process_data() == test_data - - @pytest.mark.parametrize( - "test_data", - [ - {"simple": "value"}, - {"nested": {"key": "value"}}, - {"list": [1, 2, 3]}, - {"mixed": {"string": "value", "number": 42, "list": [1, 2]}}, - {}, # empty dict - ], - ) - def test_truncated_process_data_with_various_data_types(self, test_data): - """Test that truncated process_data works with various data types.""" - execution = self.create_workflow_node_execution() - - execution.set_truncated_process_data(test_data) - - assert execution.process_data_truncated is True - assert execution.get_truncated_process_data() == test_data - assert execution.get_response_process_data() == test_data - - -@dataclass -class ProcessDataScenario: - """Test scenario data for process_data functionality.""" - - name: str - 
original_data: dict[str, Any] | None - truncated_data: dict[str, Any] | None - expected_truncated_flag: bool - expected_response_data: dict[str, Any] | None - - -class TestWorkflowNodeExecutionProcessDataScenarios: - """Test various scenarios for process_data handling.""" - - def get_process_data_scenarios(self) -> list[ProcessDataScenario]: - """Create test scenarios for process_data functionality.""" - return [ - ProcessDataScenario( - name="no_process_data", - original_data=None, - truncated_data=None, - expected_truncated_flag=False, - expected_response_data=None, - ), - ProcessDataScenario( - name="process_data_without_truncation", - original_data={"small": "data"}, - truncated_data=None, - expected_truncated_flag=False, - expected_response_data={"small": "data"}, - ), - ProcessDataScenario( - name="process_data_with_truncation", - original_data={"large": "x" * 10000, "metadata": "info"}, - truncated_data={"large": "[TRUNCATED]", "metadata": "info"}, - expected_truncated_flag=True, - expected_response_data={"large": "[TRUNCATED]", "metadata": "info"}, - ), - ProcessDataScenario( - name="empty_process_data", - original_data={}, - truncated_data=None, - expected_truncated_flag=False, - expected_response_data={}, - ), - ProcessDataScenario( - name="complex_nested_data_with_truncation", - original_data={ - "config": {"setting": "value"}, - "logs": ["log1", "log2"] * 1000, # Large list - "status": "running", - }, - truncated_data={"config": {"setting": "value"}, "logs": "[TRUNCATED: 2000 items]", "status": "running"}, - expected_truncated_flag=True, - expected_response_data={ - "config": {"setting": "value"}, - "logs": "[TRUNCATED: 2000 items]", - "status": "running", - }, - ), - ] - - @pytest.mark.parametrize( - "scenario", - get_process_data_scenarios(None), - ids=[scenario.name for scenario in get_process_data_scenarios(None)], - ) - def test_process_data_scenarios(self, scenario: ProcessDataScenario): - """Test various process_data scenarios.""" - execution = 
WorkflowNodeExecution( - id="test-execution-id", - workflow_id="test-workflow-id", - index=1, - node_id="test-node-id", - node_type=BuiltinNodeTypes.LLM, - title="Test Node", - process_data=scenario.original_data, - created_at=datetime.now(), - ) - - if scenario.truncated_data is not None: - execution.set_truncated_process_data(scenario.truncated_data) - - assert execution.process_data_truncated == scenario.expected_truncated_flag - assert execution.get_response_process_data() == scenario.expected_response_data diff --git a/api/tests/unit_tests/core/workflow/graph/test_graph.py b/api/tests/unit_tests/core/workflow/graph/test_graph.py deleted file mode 100644 index 24bd9ccbed..0000000000 --- a/api/tests/unit_tests/core/workflow/graph/test_graph.py +++ /dev/null @@ -1,281 +0,0 @@ -"""Unit tests for Graph class methods.""" - -from unittest.mock import Mock - -from dify_graph.enums import BuiltinNodeTypes, NodeExecutionType, NodeState -from dify_graph.graph.edge import Edge -from dify_graph.graph.graph import Graph -from dify_graph.nodes.base.node import Node - - -def create_mock_node(node_id: str, execution_type: NodeExecutionType, state: NodeState = NodeState.UNKNOWN) -> Node: - """Create a mock node for testing.""" - node = Mock(spec=Node) - node.id = node_id - node.execution_type = execution_type - node.state = state - node.node_type = BuiltinNodeTypes.START - return node - - -class TestMarkInactiveRootBranches: - """Test cases for _mark_inactive_root_branches method.""" - - def test_single_root_no_marking(self): - """Test that single root graph doesn't mark anything as skipped.""" - nodes = { - "root1": create_mock_node("root1", NodeExecutionType.ROOT), - "child1": create_mock_node("child1", NodeExecutionType.EXECUTABLE), - } - - edges = { - "edge1": Edge(id="edge1", tail="root1", head="child1", source_handle="source"), - } - - in_edges = {"child1": ["edge1"]} - out_edges = {"root1": ["edge1"]} - - Graph._mark_inactive_root_branches(nodes, edges, in_edges, 
out_edges, "root1") - - assert nodes["root1"].state == NodeState.UNKNOWN - assert nodes["child1"].state == NodeState.UNKNOWN - assert edges["edge1"].state == NodeState.UNKNOWN - - def test_multiple_roots_mark_inactive(self): - """Test marking inactive root branches with multiple root nodes.""" - nodes = { - "root1": create_mock_node("root1", NodeExecutionType.ROOT), - "root2": create_mock_node("root2", NodeExecutionType.ROOT), - "child1": create_mock_node("child1", NodeExecutionType.EXECUTABLE), - "child2": create_mock_node("child2", NodeExecutionType.EXECUTABLE), - } - - edges = { - "edge1": Edge(id="edge1", tail="root1", head="child1", source_handle="source"), - "edge2": Edge(id="edge2", tail="root2", head="child2", source_handle="source"), - } - - in_edges = {"child1": ["edge1"], "child2": ["edge2"]} - out_edges = {"root1": ["edge1"], "root2": ["edge2"]} - - Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "root1") - - assert nodes["root1"].state == NodeState.UNKNOWN - assert nodes["root2"].state == NodeState.SKIPPED - assert nodes["child1"].state == NodeState.UNKNOWN - assert nodes["child2"].state == NodeState.SKIPPED - assert edges["edge1"].state == NodeState.UNKNOWN - assert edges["edge2"].state == NodeState.SKIPPED - - def test_shared_downstream_node(self): - """Test that shared downstream nodes are not skipped if at least one path is active.""" - nodes = { - "root1": create_mock_node("root1", NodeExecutionType.ROOT), - "root2": create_mock_node("root2", NodeExecutionType.ROOT), - "child1": create_mock_node("child1", NodeExecutionType.EXECUTABLE), - "child2": create_mock_node("child2", NodeExecutionType.EXECUTABLE), - "shared": create_mock_node("shared", NodeExecutionType.EXECUTABLE), - } - - edges = { - "edge1": Edge(id="edge1", tail="root1", head="child1", source_handle="source"), - "edge2": Edge(id="edge2", tail="root2", head="child2", source_handle="source"), - "edge3": Edge(id="edge3", tail="child1", head="shared", 
source_handle="source"), - "edge4": Edge(id="edge4", tail="child2", head="shared", source_handle="source"), - } - - in_edges = { - "child1": ["edge1"], - "child2": ["edge2"], - "shared": ["edge3", "edge4"], - } - out_edges = { - "root1": ["edge1"], - "root2": ["edge2"], - "child1": ["edge3"], - "child2": ["edge4"], - } - - Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "root1") - - assert nodes["root1"].state == NodeState.UNKNOWN - assert nodes["root2"].state == NodeState.SKIPPED - assert nodes["child1"].state == NodeState.UNKNOWN - assert nodes["child2"].state == NodeState.SKIPPED - assert nodes["shared"].state == NodeState.UNKNOWN # Not skipped because edge3 is active - assert edges["edge1"].state == NodeState.UNKNOWN - assert edges["edge2"].state == NodeState.SKIPPED - assert edges["edge3"].state == NodeState.UNKNOWN - assert edges["edge4"].state == NodeState.SKIPPED - - def test_deep_branch_marking(self): - """Test marking deep branches with multiple levels.""" - nodes = { - "root1": create_mock_node("root1", NodeExecutionType.ROOT), - "root2": create_mock_node("root2", NodeExecutionType.ROOT), - "level1_a": create_mock_node("level1_a", NodeExecutionType.EXECUTABLE), - "level1_b": create_mock_node("level1_b", NodeExecutionType.EXECUTABLE), - "level2_a": create_mock_node("level2_a", NodeExecutionType.EXECUTABLE), - "level2_b": create_mock_node("level2_b", NodeExecutionType.EXECUTABLE), - "level3": create_mock_node("level3", NodeExecutionType.EXECUTABLE), - } - - edges = { - "edge1": Edge(id="edge1", tail="root1", head="level1_a", source_handle="source"), - "edge2": Edge(id="edge2", tail="root2", head="level1_b", source_handle="source"), - "edge3": Edge(id="edge3", tail="level1_a", head="level2_a", source_handle="source"), - "edge4": Edge(id="edge4", tail="level1_b", head="level2_b", source_handle="source"), - "edge5": Edge(id="edge5", tail="level2_b", head="level3", source_handle="source"), - } - - in_edges = { - "level1_a": ["edge1"], - 
"level1_b": ["edge2"], - "level2_a": ["edge3"], - "level2_b": ["edge4"], - "level3": ["edge5"], - } - out_edges = { - "root1": ["edge1"], - "root2": ["edge2"], - "level1_a": ["edge3"], - "level1_b": ["edge4"], - "level2_b": ["edge5"], - } - - Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "root1") - - assert nodes["root1"].state == NodeState.UNKNOWN - assert nodes["root2"].state == NodeState.SKIPPED - assert nodes["level1_a"].state == NodeState.UNKNOWN - assert nodes["level1_b"].state == NodeState.SKIPPED - assert nodes["level2_a"].state == NodeState.UNKNOWN - assert nodes["level2_b"].state == NodeState.SKIPPED - assert nodes["level3"].state == NodeState.SKIPPED - assert edges["edge1"].state == NodeState.UNKNOWN - assert edges["edge2"].state == NodeState.SKIPPED - assert edges["edge3"].state == NodeState.UNKNOWN - assert edges["edge4"].state == NodeState.SKIPPED - assert edges["edge5"].state == NodeState.SKIPPED - - def test_non_root_execution_type(self): - """Test that nodes with non-ROOT execution type are not treated as root nodes.""" - nodes = { - "root1": create_mock_node("root1", NodeExecutionType.ROOT), - "non_root": create_mock_node("non_root", NodeExecutionType.EXECUTABLE), - "child1": create_mock_node("child1", NodeExecutionType.EXECUTABLE), - "child2": create_mock_node("child2", NodeExecutionType.EXECUTABLE), - } - - edges = { - "edge1": Edge(id="edge1", tail="root1", head="child1", source_handle="source"), - "edge2": Edge(id="edge2", tail="non_root", head="child2", source_handle="source"), - } - - in_edges = {"child1": ["edge1"], "child2": ["edge2"]} - out_edges = {"root1": ["edge1"], "non_root": ["edge2"]} - - Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "root1") - - assert nodes["root1"].state == NodeState.UNKNOWN - assert nodes["non_root"].state == NodeState.UNKNOWN # Not marked as skipped - assert nodes["child1"].state == NodeState.UNKNOWN - assert nodes["child2"].state == NodeState.UNKNOWN - assert 
edges["edge1"].state == NodeState.UNKNOWN - assert edges["edge2"].state == NodeState.UNKNOWN - - def test_empty_graph(self): - """Test handling of empty graph structures.""" - nodes = {} - edges = {} - in_edges = {} - out_edges = {} - - # Should not raise any errors - Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "non_existent") - - def test_three_roots_mark_two_inactive(self): - """Test with three root nodes where two should be marked inactive.""" - nodes = { - "root1": create_mock_node("root1", NodeExecutionType.ROOT), - "root2": create_mock_node("root2", NodeExecutionType.ROOT), - "root3": create_mock_node("root3", NodeExecutionType.ROOT), - "child1": create_mock_node("child1", NodeExecutionType.EXECUTABLE), - "child2": create_mock_node("child2", NodeExecutionType.EXECUTABLE), - "child3": create_mock_node("child3", NodeExecutionType.EXECUTABLE), - } - - edges = { - "edge1": Edge(id="edge1", tail="root1", head="child1", source_handle="source"), - "edge2": Edge(id="edge2", tail="root2", head="child2", source_handle="source"), - "edge3": Edge(id="edge3", tail="root3", head="child3", source_handle="source"), - } - - in_edges = { - "child1": ["edge1"], - "child2": ["edge2"], - "child3": ["edge3"], - } - out_edges = { - "root1": ["edge1"], - "root2": ["edge2"], - "root3": ["edge3"], - } - - Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "root2") - - assert nodes["root1"].state == NodeState.SKIPPED - assert nodes["root2"].state == NodeState.UNKNOWN # Active root - assert nodes["root3"].state == NodeState.SKIPPED - assert nodes["child1"].state == NodeState.SKIPPED - assert nodes["child2"].state == NodeState.UNKNOWN - assert nodes["child3"].state == NodeState.SKIPPED - assert edges["edge1"].state == NodeState.SKIPPED - assert edges["edge2"].state == NodeState.UNKNOWN - assert edges["edge3"].state == NodeState.SKIPPED - - def test_convergent_paths(self): - """Test convergent paths where multiple inactive branches lead to same 
node.""" - nodes = { - "root1": create_mock_node("root1", NodeExecutionType.ROOT), - "root2": create_mock_node("root2", NodeExecutionType.ROOT), - "root3": create_mock_node("root3", NodeExecutionType.ROOT), - "mid1": create_mock_node("mid1", NodeExecutionType.EXECUTABLE), - "mid2": create_mock_node("mid2", NodeExecutionType.EXECUTABLE), - "convergent": create_mock_node("convergent", NodeExecutionType.EXECUTABLE), - } - - edges = { - "edge1": Edge(id="edge1", tail="root1", head="mid1", source_handle="source"), - "edge2": Edge(id="edge2", tail="root2", head="mid2", source_handle="source"), - "edge3": Edge(id="edge3", tail="root3", head="convergent", source_handle="source"), - "edge4": Edge(id="edge4", tail="mid1", head="convergent", source_handle="source"), - "edge5": Edge(id="edge5", tail="mid2", head="convergent", source_handle="source"), - } - - in_edges = { - "mid1": ["edge1"], - "mid2": ["edge2"], - "convergent": ["edge3", "edge4", "edge5"], - } - out_edges = { - "root1": ["edge1"], - "root2": ["edge2"], - "root3": ["edge3"], - "mid1": ["edge4"], - "mid2": ["edge5"], - } - - Graph._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, "root1") - - assert nodes["root1"].state == NodeState.UNKNOWN - assert nodes["root2"].state == NodeState.SKIPPED - assert nodes["root3"].state == NodeState.SKIPPED - assert nodes["mid1"].state == NodeState.UNKNOWN - assert nodes["mid2"].state == NodeState.SKIPPED - assert nodes["convergent"].state == NodeState.UNKNOWN # Not skipped due to active path from root1 - assert edges["edge1"].state == NodeState.UNKNOWN - assert edges["edge2"].state == NodeState.SKIPPED - assert edges["edge3"].state == NodeState.SKIPPED - assert edges["edge4"].state == NodeState.UNKNOWN - assert edges["edge5"].state == NodeState.SKIPPED diff --git a/api/tests/unit_tests/core/workflow/graph/test_graph_builder.py b/api/tests/unit_tests/core/workflow/graph/test_graph_builder.py deleted file mode 100644 index 64c2eee776..0000000000 --- 
a/api/tests/unit_tests/core/workflow/graph/test_graph_builder.py +++ /dev/null @@ -1,59 +0,0 @@ -from unittest.mock import MagicMock - -import pytest - -from dify_graph.enums import BuiltinNodeTypes, NodeType -from dify_graph.graph import Graph -from dify_graph.nodes.base.node import Node - - -def _make_node(node_id: str, node_type: NodeType = BuiltinNodeTypes.START) -> Node: - node = MagicMock(spec=Node) - node.id = node_id - node.node_type = node_type - node.execution_type = None # attribute not used in builder path - return node - - -def test_graph_builder_creates_linear_graph(): - builder = Graph.new() - root = _make_node("root", BuiltinNodeTypes.START) - mid = _make_node("mid", BuiltinNodeTypes.LLM) - end = _make_node("end", BuiltinNodeTypes.END) - - graph = builder.add_root(root).add_node(mid).add_node(end).build() - - assert graph.root_node is root - assert graph.nodes == {"root": root, "mid": mid, "end": end} - assert len(graph.edges) == 2 - first_edge = next(iter(graph.edges.values())) - assert first_edge.tail == "root" - assert first_edge.head == "mid" - assert graph.out_edges["mid"] == [edge_id for edge_id, edge in graph.edges.items() if edge.tail == "mid"] - - -def test_graph_builder_supports_custom_predecessor(): - builder = Graph.new() - root = _make_node("root") - branch = _make_node("branch") - other = _make_node("other") - - graph = builder.add_root(root).add_node(branch).add_node(other, from_node_id="root").build() - - outgoing_root = graph.out_edges["root"] - assert len(outgoing_root) == 2 - edge_targets = {graph.edges[eid].head for eid in outgoing_root} - assert edge_targets == {"branch", "other"} - - -def test_graph_builder_validates_usage(): - builder = Graph.new() - node = _make_node("node") - - with pytest.raises(ValueError, match="Root node"): - builder.add_node(node) - - builder.add_root(node) - duplicate = _make_node("node") - with pytest.raises(ValueError, match="Duplicate"): - builder.add_node(duplicate) diff --git 
a/api/tests/unit_tests/core/workflow/graph/test_graph_skip_validation.py b/api/tests/unit_tests/core/workflow/graph/test_graph_skip_validation.py deleted file mode 100644 index 75de07bd8b..0000000000 --- a/api/tests/unit_tests/core/workflow/graph/test_graph_skip_validation.py +++ /dev/null @@ -1,118 +0,0 @@ -from __future__ import annotations - -from typing import Any - -import pytest - -from core.workflow.node_factory import DifyNodeFactory -from dify_graph.graph import Graph -from dify_graph.graph.validation import GraphValidationError -from dify_graph.nodes import BuiltinNodeTypes -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable -from tests.workflow_test_utils import build_test_graph_init_params - - -def _build_iteration_graph(node_id: str) -> dict[str, Any]: - return { - "nodes": [ - { - "id": node_id, - "data": { - "type": "iteration", - "title": "Iteration", - "iterator_selector": ["start", "items"], - "output_selector": [node_id, "output"], - }, - } - ], - "edges": [], - } - - -def _build_loop_graph(node_id: str) -> dict[str, Any]: - return { - "nodes": [ - { - "id": node_id, - "data": { - "type": "loop", - "title": "Loop", - "loop_count": 1, - "break_conditions": [], - "logical_operator": "and", - "loop_variables": [], - "outputs": {}, - }, - } - ], - "edges": [], - } - - -def _make_factory(graph_config: dict[str, Any]) -> DifyNodeFactory: - graph_init_params = build_test_graph_init_params( - workflow_id="workflow", - graph_config=graph_config, - tenant_id="tenant", - app_id="app", - user_id="user", - user_from="account", - invoke_from="debugger", - call_depth=0, - ) - graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool( - system_variables=SystemVariable.default(), - user_inputs={}, - environment_variables=[], - ), - start_at=0.0, - ) - return DifyNodeFactory(graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state) - - -def 
test_iteration_root_requires_skip_validation(): - node_id = "iteration-node" - graph_config = _build_iteration_graph(node_id) - node_factory = _make_factory(graph_config) - - with pytest.raises(GraphValidationError): - Graph.init( - graph_config=graph_config, - node_factory=node_factory, - root_node_id=node_id, - ) - - graph = Graph.init( - graph_config=graph_config, - node_factory=node_factory, - root_node_id=node_id, - skip_validation=True, - ) - - assert graph.root_node.id == node_id - assert graph.root_node.node_type == BuiltinNodeTypes.ITERATION - - -def test_loop_root_requires_skip_validation(): - node_id = "loop-node" - graph_config = _build_loop_graph(node_id) - node_factory = _make_factory(graph_config) - - with pytest.raises(GraphValidationError): - Graph.init( - graph_config=graph_config, - node_factory=node_factory, - root_node_id=node_id, - ) - - graph = Graph.init( - graph_config=graph_config, - node_factory=node_factory, - root_node_id=node_id, - skip_validation=True, - ) - - assert graph.root_node.id == node_id - assert graph.root_node.node_type == BuiltinNodeTypes.LOOP diff --git a/api/tests/unit_tests/core/workflow/graph/test_graph_validation.py b/api/tests/unit_tests/core/workflow/graph/test_graph_validation.py deleted file mode 100644 index e94ad74eb0..0000000000 --- a/api/tests/unit_tests/core/workflow/graph/test_graph_validation.py +++ /dev/null @@ -1,219 +0,0 @@ -from __future__ import annotations - -import time -from collections.abc import Mapping -from dataclasses import dataclass - -import pytest - -from dify_graph.entities import GraphInitParams -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, ErrorStrategy, NodeExecutionType, NodeType -from dify_graph.graph import Graph -from dify_graph.graph.validation import GraphValidationError -from dify_graph.nodes.base.node import Node -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import 
SystemVariable -from tests.workflow_test_utils import build_test_graph_init_params - - -class _TestNodeData(BaseNodeData): - type: NodeType | None = None - execution_type: NodeExecutionType | str | None = None - - -class _TestNode(Node[_TestNodeData]): - node_type = BuiltinNodeTypes.ANSWER - execution_type = NodeExecutionType.EXECUTABLE - - @classmethod - def version(cls) -> str: - return "1" - - def __init__( - self, - *, - id: str, - config: Mapping[str, object], - graph_init_params: GraphInitParams, - graph_runtime_state: GraphRuntimeState, - ) -> None: - super().__init__( - id=id, - config=config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - - node_type_value = self.data.get("type") - if isinstance(node_type_value, str): - self.node_type = node_type_value - - def _run(self): - raise NotImplementedError - - def post_init(self) -> None: - super().post_init() - self._maybe_override_execution_type() - self.data = dict(self.node_data.model_dump()) - - def _maybe_override_execution_type(self) -> None: - execution_type_value = self.node_data.execution_type - if execution_type_value is None: - return - if isinstance(execution_type_value, NodeExecutionType): - self.execution_type = execution_type_value - else: - self.execution_type = NodeExecutionType(execution_type_value) - - -@dataclass(slots=True) -class _SimpleNodeFactory: - graph_init_params: GraphInitParams - graph_runtime_state: GraphRuntimeState - - def create_node(self, node_config: Mapping[str, object]) -> _TestNode: - node_id = str(node_config["id"]) - node = _TestNode( - id=node_id, - config=node_config, - graph_init_params=self.graph_init_params, - graph_runtime_state=self.graph_runtime_state, - ) - return node - - -@pytest.fixture -def graph_init_dependencies() -> tuple[_SimpleNodeFactory, dict[str, object]]: - graph_config: dict[str, object] = {"edges": [], "nodes": []} - init_params = build_test_graph_init_params( - workflow_id="workflow", - 
graph_config=graph_config, - tenant_id="tenant", - app_id="app", - user_id="user", - user_from="account", - invoke_from="service-api", - call_depth=0, - ) - variable_pool = VariablePool(system_variables=SystemVariable(user_id="user", files=[]), user_inputs={}) - runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - factory = _SimpleNodeFactory(graph_init_params=init_params, graph_runtime_state=runtime_state) - return factory, graph_config - - -def test_graph_initialization_runs_default_validators( - graph_init_dependencies: tuple[_SimpleNodeFactory, dict[str, object]], -): - node_factory, graph_config = graph_init_dependencies - graph_config["nodes"] = [ - { - "id": "start", - "data": {"type": BuiltinNodeTypes.START, "title": "Start", "execution_type": NodeExecutionType.ROOT}, - }, - {"id": "answer", "data": {"type": BuiltinNodeTypes.ANSWER, "title": "Answer"}}, - ] - graph_config["edges"] = [ - {"source": "start", "target": "answer", "sourceHandle": "success"}, - ] - - graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") - - assert graph.root_node.id == "start" - assert "answer" in graph.nodes - - -def test_graph_validation_fails_for_unknown_edge_targets( - graph_init_dependencies: tuple[_SimpleNodeFactory, dict[str, object]], -) -> None: - node_factory, graph_config = graph_init_dependencies - graph_config["nodes"] = [ - { - "id": "start", - "data": {"type": BuiltinNodeTypes.START, "title": "Start", "execution_type": NodeExecutionType.ROOT}, - }, - ] - graph_config["edges"] = [ - {"source": "start", "target": "missing", "sourceHandle": "success"}, - ] - - with pytest.raises(GraphValidationError) as exc: - Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") - - assert any(issue.code == "MISSING_NODE" for issue in exc.value.issues) - - -def test_graph_promotes_fail_branch_nodes_to_branch_execution_type( - graph_init_dependencies: 
tuple[_SimpleNodeFactory, dict[str, object]], -) -> None: - node_factory, graph_config = graph_init_dependencies - graph_config["nodes"] = [ - { - "id": "start", - "data": {"type": BuiltinNodeTypes.START, "title": "Start", "execution_type": NodeExecutionType.ROOT}, - }, - { - "id": "branch", - "data": { - "type": BuiltinNodeTypes.IF_ELSE, - "title": "Branch", - "error_strategy": ErrorStrategy.FAIL_BRANCH, - }, - }, - ] - graph_config["edges"] = [ - {"source": "start", "target": "branch", "sourceHandle": "success"}, - ] - - graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") - - assert graph.nodes["branch"].execution_type == NodeExecutionType.BRANCH - - -def test_graph_init_ignores_custom_note_nodes_before_node_data_validation( - graph_init_dependencies: tuple[_SimpleNodeFactory, dict[str, object]], -) -> None: - node_factory, graph_config = graph_init_dependencies - graph_config["nodes"] = [ - { - "id": "start", - "data": {"type": BuiltinNodeTypes.START, "title": "Start", "execution_type": NodeExecutionType.ROOT}, - }, - {"id": "answer", "data": {"type": BuiltinNodeTypes.ANSWER, "title": "Answer"}}, - { - "id": "note", - "type": "custom-note", - "data": { - "type": "", - "title": "", - "desc": "", - "text": "{}", - "theme": "blue", - }, - }, - ] - graph_config["edges"] = [ - {"source": "start", "target": "answer", "sourceHandle": "success"}, - ] - - graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") - - assert graph.root_node.id == "start" - assert "answer" in graph.nodes - assert "note" not in graph.nodes - - -def test_graph_init_fails_for_unknown_root_node_id( - graph_init_dependencies: tuple[_SimpleNodeFactory, dict[str, object]], -) -> None: - node_factory, graph_config = graph_init_dependencies - graph_config["nodes"] = [ - { - "id": "start", - "data": {"type": BuiltinNodeTypes.START, "title": "Start", "execution_type": NodeExecutionType.ROOT}, - }, - ] - 
graph_config["edges"] = [] - - with pytest.raises(ValueError, match="Root node id missing not found in the graph"): - Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="missing") diff --git a/api/tests/unit_tests/core/workflow/graph_engine/README.md b/api/tests/unit_tests/core/workflow/graph_engine/README.md index 40ed61eb02..dd419f0810 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/README.md +++ b/api/tests/unit_tests/core/workflow/graph_engine/README.md @@ -1,441 +1,30 @@ -# Graph Engine Testing Framework +# Workflow Graph Engine Smoke Tests -## Overview +This directory now keeps only a small Dify-owned smoke layer around the external +`graphon` package. -This directory contains a comprehensive testing framework for the Graph Engine, including: +Retained coverage focuses on: -1. **TableTestRunner** - Advanced table-driven test framework for workflow testing -1. **Auto-Mock System** - Powerful mocking framework for testing without external dependencies +1. Dify workflow layers: + - `layers/test_llm_quota.py` + - `layers/test_observability.py` +2. Human-input resume integration: + - `test_parallel_human_input_join_resume.py` +3. One mocked tool/chatflow smoke path: + - `test_tool_in_chatflow.py` -## TableTestRunner Framework +The helper modules below remain only because the retained smoke tests use them: -The TableTestRunner (`test_table_runner.py`) provides a robust table-driven testing framework for GraphEngine workflows. +1. `test_mock_config.py` +2. `test_mock_factory.py` +3. `test_mock_nodes.py` +4. 
`test_table_runner.py` -### Features - -- **Table-driven testing** - Define test cases as structured data -- **Parallel test execution** - Run tests concurrently for faster execution -- **Property-based testing** - Integration with Hypothesis for fuzzing -- **Event sequence validation** - Verify correct event ordering -- **Mock configuration** - Seamless integration with the auto-mock system -- **Performance metrics** - Track execution times and bottlenecks -- **Detailed error reporting** - Comprehensive failure diagnostics - -### Basic Usage - -```python -from test_table_runner import TableTestRunner, WorkflowTestCase - -# Create test runner -runner = TableTestRunner() - -# Define test case -test_case = WorkflowTestCase( - fixture_path="simple_workflow", - inputs={"query": "Hello"}, - expected_outputs={"result": "World"}, - description="Basic workflow test", -) - -# Run single test -result = runner.run_test_case(test_case) -assert result.success -``` - -### Advanced Features - -#### Parallel Execution - -```python -runner = TableTestRunner(max_workers=8) - -test_cases = [ - WorkflowTestCase(...), - WorkflowTestCase(...), - # ... 
more test cases -] - -# Run tests in parallel -suite_result = runner.run_table_tests( - test_cases, - parallel=True, - fail_fast=False -) - -print(f"Success rate: {suite_result.success_rate:.1f}%") -``` - -#### Event Sequence Validation - -```python -from dify_graph.graph_events import ( - GraphRunStartedEvent, - NodeRunStartedEvent, - NodeRunSucceededEvent, - GraphRunSucceededEvent, -) - -test_case = WorkflowTestCase( - fixture_path="workflow", - inputs={}, - expected_outputs={}, - expected_event_sequence=[ - GraphRunStartedEvent, - NodeRunStartedEvent, - NodeRunSucceededEvent, - GraphRunSucceededEvent, - ] -) -``` - -### Test Suite Reports - -```python -# Run test suite -suite_result = runner.run_table_tests(test_cases) - -# Generate detailed report -report = runner.generate_report(suite_result) -print(report) - -# Access specific results -failed_results = suite_result.get_failed_results() -for result in failed_results: - print(f"Failed: {result.test_case.description}") - print(f" Error: {result.error}") -``` - -### Performance Testing - -```python -# Enable logging for performance insights -runner = TableTestRunner( - enable_logging=True, - log_level="DEBUG" -) - -# Run tests and analyze performance -suite_result = runner.run_table_tests(test_cases) - -# Get slowest tests -sorted_results = sorted( - suite_result.results, - key=lambda r: r.execution_time, - reverse=True -) - -print("Slowest tests:") -for result in sorted_results[:5]: - print(f" {result.test_case.description}: {result.execution_time:.2f}s") -``` - -## Integration: TableTestRunner + Auto-Mock System - -The TableTestRunner seamlessly integrates with the auto-mock system for comprehensive workflow testing: - -```python -from test_table_runner import TableTestRunner, WorkflowTestCase -from test_mock_config import MockConfigBuilder - -# Configure mocks -mock_config = (MockConfigBuilder() - .with_llm_response("Mocked LLM response") - .with_tool_response({"result": "mocked"}) - .with_delays(True) # 
Simulate realistic delays - .build()) - -# Create test case with mocking -test_case = WorkflowTestCase( - fixture_path="complex_workflow", - inputs={"query": "test"}, - expected_outputs={"answer": "Mocked LLM response"}, - use_auto_mock=True, # Enable auto-mocking - mock_config=mock_config, - description="Test with mocked services", -) - -# Run test -runner = TableTestRunner() -result = runner.run_test_case(test_case) -``` - -## Auto-Mock System - -The auto-mock system provides a powerful framework for testing workflows that contain nodes requiring third-party services (LLM, APIs, tools, etc.) without making actual external calls. This enables: - -- **Fast test execution** - No network latency or API rate limits -- **Deterministic results** - Consistent outputs for reliable testing -- **Cost savings** - No API usage charges during testing -- **Offline testing** - Tests can run without internet connectivity -- **Error simulation** - Test error handling without triggering real failures - -## Architecture - -The auto-mock system consists of three main components: - -### 1. MockNodeFactory (`test_mock_factory.py`) - -- Extends `DifyNodeFactory` to intercept node creation -- Automatically detects nodes requiring third-party services -- Returns mock node implementations instead of real ones -- Supports registration of custom mock implementations - -### 2. Mock Node Implementations (`test_mock_nodes.py`) - -- `MockLLMNode` - Mocks LLM API calls (OpenAI, Anthropic, etc.) -- `MockAgentNode` - Mocks agent execution -- `MockToolNode` - Mocks tool invocations -- `MockKnowledgeRetrievalNode` - Mocks knowledge base queries -- `MockHttpRequestNode` - Mocks HTTP requests -- `MockParameterExtractorNode` - Mocks parameter extraction -- `MockDocumentExtractorNode` - Mocks document processing -- `MockQuestionClassifierNode` - Mocks question classification - -### 3. 
Mock Configuration (`test_mock_config.py`) - -- `MockConfig` - Global configuration for mock behavior -- `NodeMockConfig` - Node-specific mock configuration -- `MockConfigBuilder` - Fluent interface for building configurations - -## Usage - -### Basic Example - -```python -from test_graph_engine import TableTestRunner, WorkflowTestCase -from test_mock_config import MockConfigBuilder - -# Create test runner -runner = TableTestRunner() - -# Configure mock responses -mock_config = (MockConfigBuilder() - .with_llm_response("Mocked LLM response") - .build()) - -# Define test case -test_case = WorkflowTestCase( - fixture_path="llm-simple", - inputs={"query": "Hello"}, - expected_outputs={"answer": "Mocked LLM response"}, - use_auto_mock=True, # Enable auto-mocking - mock_config=mock_config, -) - -# Run test -result = runner.run_test_case(test_case) -assert result.success -``` - -### Custom Node Outputs - -```python -# Configure specific outputs for individual nodes -mock_config = MockConfig() -mock_config.set_node_outputs("llm_node_123", { - "text": "Custom response for this specific node", - "usage": {"total_tokens": 50}, - "finish_reason": "stop", -}) -``` - -### Error Simulation - -```python -# Simulate node failures for error handling tests -mock_config = MockConfig() -mock_config.set_node_error("http_node", "Connection timeout") -``` - -### Simulated Delays - -```python -# Add realistic execution delays -from test_mock_config import NodeMockConfig - -node_config = NodeMockConfig( - node_id="llm_node", - outputs={"text": "Response"}, - delay=1.5, # 1.5 second delay -) -mock_config.set_node_config("llm_node", node_config) -``` - -### Custom Handlers - -```python -# Define custom logic for mock outputs -def custom_handler(node): - # Access node state and return dynamic outputs - return { - "text": f"Processed: {node.graph_runtime_state.variable_pool.get('query')}", - } - -node_config = NodeMockConfig( - node_id="llm_node", - custom_handler=custom_handler, -) -``` - -## 
Node Types Automatically Mocked - -The following node types are automatically mocked when `use_auto_mock=True`: - -- `LLM` - Language model nodes -- `AGENT` - Agent execution nodes -- `TOOL` - Tool invocation nodes -- `KNOWLEDGE_RETRIEVAL` - Knowledge base query nodes -- `HTTP_REQUEST` - HTTP request nodes -- `PARAMETER_EXTRACTOR` - Parameter extraction nodes -- `DOCUMENT_EXTRACTOR` - Document processing nodes -- `QUESTION_CLASSIFIER` - Question classification nodes - -## Advanced Features - -### Registering Custom Mock Implementations - -```python -from test_mock_factory import MockNodeFactory - -# Create custom mock implementation -class CustomMockNode(BaseNode): - def _run(self): - # Custom mock logic - pass - -# Register for a specific node type -factory = MockNodeFactory(...) -factory.register_mock_node_type(NodeType.CUSTOM, CustomMockNode) -``` - -### Default Configurations by Node Type - -```python -# Set defaults for all nodes of a specific type -mock_config.set_default_config(NodeType.LLM, { - "temperature": 0.7, - "max_tokens": 100, -}) -``` - -### MockConfigBuilder Fluent API - -```python -config = (MockConfigBuilder() - .with_llm_response("LLM response") - .with_agent_response("Agent response") - .with_tool_response({"result": "data"}) - .with_retrieval_response("Retrieved content") - .with_http_response({"status_code": 200, "body": "{}"}) - .with_node_output("node_id", {"output": "value"}) - .with_node_error("error_node", "Error message") - .with_delays(True) - .build()) -``` - -## Testing Workflows - -### 1. Create Workflow Fixture - -Create a YAML fixture file in `api/tests/fixtures/workflow/` directory defining your workflow graph. - -### 2. Configure Mocks - -Set up mock configurations for nodes that need third-party services. - -### 3. Define Test Cases - -Create `WorkflowTestCase` instances with inputs, expected outputs, and mock config. - -### 4. Run Tests - -Use `TableTestRunner` to execute test cases and validate results. 
- -## Best Practices - -1. **Use descriptive mock responses** - Make it clear in outputs that they are mocked -1. **Test both success and failure paths** - Use error simulation to test error handling -1. **Keep mock configs close to tests** - Define mocks in the same test file for clarity -1. **Use custom handlers sparingly** - Only when dynamic behavior is needed -1. **Document mock behavior** - Comment why specific mock values are chosen -1. **Validate mock accuracy** - Ensure mocks reflect real service behavior - -## Examples - -See `test_mock_example.py` for comprehensive examples including: - -- Basic LLM workflow testing -- Custom node outputs -- HTTP and tool workflow testing -- Error simulation -- Performance testing with delays - -## Running Tests - -### TableTestRunner Tests +Examples: ```bash -# Run graph engine tests (includes property-based tests) -uv run pytest api/tests/unit_tests/dify_graph/graph_engine/test_graph_engine.py - -# Run with specific test patterns -uv run pytest api/tests/unit_tests/dify_graph/graph_engine/test_graph_engine.py -k "test_echo" - -# Run with verbose output -uv run pytest api/tests/unit_tests/dify_graph/graph_engine/test_graph_engine.py -v +uv run --project api --dev pytest api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py +uv run --project api --dev pytest api/tests/unit_tests/core/workflow/graph_engine/layers/test_observability.py +uv run --project api --dev pytest api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py +uv run --project api --dev pytest api/tests/unit_tests/core/workflow/graph_engine/test_tool_in_chatflow.py ``` - -### Mock System Tests - -```bash -# Run auto-mock system tests -uv run pytest api/tests/unit_tests/dify_graph/graph_engine/test_auto_mock_system.py - -# Run examples -uv run python api/tests/unit_tests/dify_graph/graph_engine/test_mock_example.py - -# Run simple validation -uv run python 
api/tests/unit_tests/dify_graph/graph_engine/test_mock_simple.py -``` - -### All Tests - -```bash -# Run all graph engine tests -uv run pytest api/tests/unit_tests/dify_graph/graph_engine/ - -# Run with coverage -uv run pytest api/tests/unit_tests/dify_graph/graph_engine/ --cov=dify_graph.graph_engine - -# Run in parallel -uv run pytest api/tests/unit_tests/dify_graph/graph_engine/ -n auto -``` - -## Troubleshooting - -### Issue: Mock not being applied - -- Ensure `use_auto_mock=True` in `WorkflowTestCase` -- Verify node ID matches in mock config -- Check that node type is in the auto-mock list - -### Issue: Unexpected outputs - -- Debug by printing `result.actual_outputs` -- Check if custom handler is overriding expected outputs -- Verify mock config is properly built - -### Issue: Import errors - -- Ensure all mock modules are in the correct path -- Check that required dependencies are installed - -## Future Enhancements - -Potential improvements to the auto-mock system: - -1. **Recording and playback** - Record real API responses for replay in tests -1. **Mock templates** - Pre-defined mock configurations for common scenarios -1. **Async support** - Better support for async node execution -1. **Mock validation** - Validate mock outputs against node schemas -1. 
**Performance profiling** - Built-in performance metrics for mocked workflows diff --git a/api/tests/unit_tests/core/workflow/graph_engine/command_channels/test_redis_channel.py b/api/tests/unit_tests/core/workflow/graph_engine/command_channels/test_redis_channel.py deleted file mode 100644 index 4dec618e49..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/command_channels/test_redis_channel.py +++ /dev/null @@ -1,315 +0,0 @@ -"""Tests for Redis command channel implementation.""" - -import json -from unittest.mock import MagicMock - -from dify_graph.graph_engine.command_channels.redis_channel import RedisChannel -from dify_graph.graph_engine.entities.commands import ( - AbortCommand, - CommandType, - GraphEngineCommand, - UpdateVariablesCommand, - VariableUpdate, -) -from dify_graph.variables import IntegerVariable, StringVariable - - -class TestRedisChannel: - """Test suite for RedisChannel functionality.""" - - def test_init(self): - """Test RedisChannel initialization.""" - mock_redis = MagicMock() - channel_key = "test:channel:key" - ttl = 7200 - - channel = RedisChannel(mock_redis, channel_key, ttl) - - assert channel._redis == mock_redis - assert channel._key == channel_key - assert channel._command_ttl == ttl - - def test_init_default_ttl(self): - """Test RedisChannel initialization with default TTL.""" - mock_redis = MagicMock() - channel_key = "test:channel:key" - - channel = RedisChannel(mock_redis, channel_key) - - assert channel._command_ttl == 3600 # Default TTL - - def test_send_command(self): - """Test sending a command to Redis.""" - mock_redis = MagicMock() - mock_pipe = MagicMock() - context = MagicMock() - context.__enter__.return_value = mock_pipe - context.__exit__.return_value = None - mock_redis.pipeline.return_value = context - - channel = RedisChannel(mock_redis, "test:key", 3600) - - pending_key = "test:key:pending" - - # Create a test command - command = GraphEngineCommand(command_type=CommandType.ABORT) - - # Send the 
command - channel.send_command(command) - - # Verify pipeline was used - mock_redis.pipeline.assert_called_once() - - # Verify rpush was called with correct data - expected_json = json.dumps(command.model_dump()) - mock_pipe.rpush.assert_called_once_with("test:key", expected_json) - - # Verify expire was set - mock_pipe.expire.assert_called_once_with("test:key", 3600) - mock_pipe.set.assert_called_once_with(pending_key, "1", ex=3600) - - # Verify execute was called - mock_pipe.execute.assert_called_once() - - def test_fetch_commands_empty(self): - """Test fetching commands when Redis list is empty.""" - mock_redis = MagicMock() - pending_pipe = MagicMock() - fetch_pipe = MagicMock() - pending_context = MagicMock() - fetch_context = MagicMock() - pending_context.__enter__.return_value = pending_pipe - pending_context.__exit__.return_value = None - fetch_context.__enter__.return_value = fetch_pipe - fetch_context.__exit__.return_value = None - mock_redis.pipeline.side_effect = [pending_context] - - # No pending marker - pending_pipe.execute.return_value = [None, 0] - mock_redis.llen.return_value = 0 - - channel = RedisChannel(mock_redis, "test:key") - commands = channel.fetch_commands() - - assert commands == [] - mock_redis.pipeline.assert_called_once() - fetch_pipe.lrange.assert_not_called() - fetch_pipe.delete.assert_not_called() - - def test_fetch_commands_with_abort_command(self): - """Test fetching abort commands from Redis.""" - mock_redis = MagicMock() - pending_pipe = MagicMock() - fetch_pipe = MagicMock() - pending_context = MagicMock() - fetch_context = MagicMock() - pending_context.__enter__.return_value = pending_pipe - pending_context.__exit__.return_value = None - fetch_context.__enter__.return_value = fetch_pipe - fetch_context.__exit__.return_value = None - mock_redis.pipeline.side_effect = [pending_context, fetch_context] - - # Create abort command data - abort_command = AbortCommand() - command_json = json.dumps(abort_command.model_dump()) - - # 
Simulate Redis returning one command - pending_pipe.execute.return_value = [b"1", 1] - fetch_pipe.execute.return_value = [[command_json.encode()], 1] - - channel = RedisChannel(mock_redis, "test:key") - commands = channel.fetch_commands() - - assert len(commands) == 1 - assert isinstance(commands[0], AbortCommand) - assert commands[0].command_type == CommandType.ABORT - - def test_fetch_commands_multiple(self): - """Test fetching multiple commands from Redis.""" - mock_redis = MagicMock() - pending_pipe = MagicMock() - fetch_pipe = MagicMock() - pending_context = MagicMock() - fetch_context = MagicMock() - pending_context.__enter__.return_value = pending_pipe - pending_context.__exit__.return_value = None - fetch_context.__enter__.return_value = fetch_pipe - fetch_context.__exit__.return_value = None - mock_redis.pipeline.side_effect = [pending_context, fetch_context] - - # Create multiple commands - command1 = GraphEngineCommand(command_type=CommandType.ABORT) - command2 = AbortCommand() - - command1_json = json.dumps(command1.model_dump()) - command2_json = json.dumps(command2.model_dump()) - - # Simulate Redis returning multiple commands - pending_pipe.execute.return_value = [b"1", 1] - fetch_pipe.execute.return_value = [[command1_json.encode(), command2_json.encode()], 1] - - channel = RedisChannel(mock_redis, "test:key") - commands = channel.fetch_commands() - - assert len(commands) == 2 - assert commands[0].command_type == CommandType.ABORT - assert isinstance(commands[1], AbortCommand) - - def test_fetch_commands_with_update_variables_command(self): - """Test fetching update variables command from Redis.""" - mock_redis = MagicMock() - pending_pipe = MagicMock() - fetch_pipe = MagicMock() - pending_context = MagicMock() - fetch_context = MagicMock() - pending_context.__enter__.return_value = pending_pipe - pending_context.__exit__.return_value = None - fetch_context.__enter__.return_value = fetch_pipe - fetch_context.__exit__.return_value = None - 
mock_redis.pipeline.side_effect = [pending_context, fetch_context] - - update_command = UpdateVariablesCommand( - updates=[ - VariableUpdate( - value=StringVariable(name="foo", value="bar", selector=["node1", "foo"]), - ), - VariableUpdate( - value=IntegerVariable(name="baz", value=123, selector=["node2", "baz"]), - ), - ] - ) - command_json = json.dumps(update_command.model_dump()) - - pending_pipe.execute.return_value = [b"1", 1] - fetch_pipe.execute.return_value = [[command_json.encode()], 1] - - channel = RedisChannel(mock_redis, "test:key") - commands = channel.fetch_commands() - - assert len(commands) == 1 - assert isinstance(commands[0], UpdateVariablesCommand) - assert isinstance(commands[0].updates[0].value, StringVariable) - assert list(commands[0].updates[0].value.selector) == ["node1", "foo"] - assert commands[0].updates[0].value.value == "bar" - - def test_fetch_commands_skips_invalid_json(self): - """Test that invalid JSON commands are skipped.""" - mock_redis = MagicMock() - pending_pipe = MagicMock() - fetch_pipe = MagicMock() - pending_context = MagicMock() - fetch_context = MagicMock() - pending_context.__enter__.return_value = pending_pipe - pending_context.__exit__.return_value = None - fetch_context.__enter__.return_value = fetch_pipe - fetch_context.__exit__.return_value = None - mock_redis.pipeline.side_effect = [pending_context, fetch_context] - - # Mix valid and invalid JSON - valid_command = AbortCommand() - valid_json = json.dumps(valid_command.model_dump()) - invalid_json = b"invalid json {" - - # Simulate Redis returning mixed valid/invalid commands - pending_pipe.execute.return_value = [b"1", 1] - fetch_pipe.execute.return_value = [[invalid_json, valid_json.encode()], 1] - - channel = RedisChannel(mock_redis, "test:key") - commands = channel.fetch_commands() - - # Should only return the valid command - assert len(commands) == 1 - assert isinstance(commands[0], AbortCommand) - - def test_deserialize_command_abort(self): - """Test 
deserializing an abort command.""" - channel = RedisChannel(MagicMock(), "test:key") - - abort_data = {"command_type": CommandType.ABORT} - command = channel._deserialize_command(abort_data) - - assert isinstance(command, AbortCommand) - assert command.command_type == CommandType.ABORT - - def test_deserialize_command_generic(self): - """Test deserializing a generic command.""" - channel = RedisChannel(MagicMock(), "test:key") - - # For now, only ABORT is supported, but test generic handling - generic_data = {"command_type": CommandType.ABORT} - command = channel._deserialize_command(generic_data) - - assert command is not None - assert command.command_type == CommandType.ABORT - - def test_deserialize_command_invalid(self): - """Test deserializing invalid command data.""" - channel = RedisChannel(MagicMock(), "test:key") - - # Missing command_type - invalid_data = {"some_field": "value"} - command = channel._deserialize_command(invalid_data) - - assert command is None - - def test_deserialize_command_invalid_type(self): - """Test deserializing command with invalid type.""" - channel = RedisChannel(MagicMock(), "test:key") - - # Invalid command type - invalid_data = {"command_type": "INVALID_TYPE"} - command = channel._deserialize_command(invalid_data) - - assert command is None - - def test_atomic_fetch_and_clear(self): - """Test that fetch_commands atomically fetches and clears the list.""" - mock_redis = MagicMock() - pending_pipe = MagicMock() - fetch_pipe = MagicMock() - pending_context = MagicMock() - fetch_context = MagicMock() - pending_context.__enter__.return_value = pending_pipe - pending_context.__exit__.return_value = None - fetch_context.__enter__.return_value = fetch_pipe - fetch_context.__exit__.return_value = None - mock_redis.pipeline.side_effect = [pending_context, fetch_context] - - command = AbortCommand() - command_json = json.dumps(command.model_dump()) - pending_pipe.execute.return_value = [b"1", 1] - fetch_pipe.execute.return_value = 
[[command_json.encode()], 1] - - channel = RedisChannel(mock_redis, "test:key") - - # First fetch should return the command - commands = channel.fetch_commands() - assert len(commands) == 1 - - # Verify both lrange and delete were called in the pipeline - assert fetch_pipe.lrange.call_count == 1 - assert fetch_pipe.delete.call_count == 1 - fetch_pipe.lrange.assert_called_with("test:key", 0, -1) - fetch_pipe.delete.assert_called_with("test:key") - - def test_fetch_commands_without_pending_marker_returns_empty(self): - """Ensure we avoid unnecessary list reads when pending flag is missing.""" - mock_redis = MagicMock() - pending_pipe = MagicMock() - fetch_pipe = MagicMock() - pending_context = MagicMock() - fetch_context = MagicMock() - pending_context.__enter__.return_value = pending_pipe - pending_context.__exit__.return_value = None - fetch_context.__enter__.return_value = fetch_pipe - fetch_context.__exit__.return_value = None - mock_redis.pipeline.side_effect = [pending_context, fetch_context] - - # Pending flag absent - pending_pipe.execute.return_value = [None, 0] - channel = RedisChannel(mock_redis, "test:key") - commands = channel.fetch_commands() - - assert commands == [] - mock_redis.llen.assert_not_called() - assert mock_redis.pipeline.call_count == 1 diff --git a/api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_handlers.py b/api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_handlers.py deleted file mode 100644 index 6f821ba799..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_handlers.py +++ /dev/null @@ -1,119 +0,0 @@ -"""Tests for graph engine event handlers.""" - -from __future__ import annotations - -from dify_graph.entities.base_node_data import RetryConfig -from dify_graph.enums import BuiltinNodeTypes, NodeExecutionType, NodeState, WorkflowNodeExecutionStatus -from dify_graph.graph import Graph -from dify_graph.graph_engine.domain.graph_execution 
import GraphExecution -from dify_graph.graph_engine.event_management.event_handlers import EventHandler -from dify_graph.graph_engine.event_management.event_manager import EventManager -from dify_graph.graph_engine.graph_state_manager import GraphStateManager -from dify_graph.graph_engine.ready_queue.in_memory import InMemoryReadyQueue -from dify_graph.graph_engine.response_coordinator.coordinator import ResponseStreamCoordinator -from dify_graph.graph_events import NodeRunRetryEvent, NodeRunStartedEvent -from dify_graph.node_events import NodeRunResult -from dify_graph.runtime import GraphRuntimeState, VariablePool -from libs.datetime_utils import naive_utc_now - - -class _StubEdgeProcessor: - """Minimal edge processor stub for tests.""" - - -class _StubErrorHandler: - """Minimal error handler stub for tests.""" - - -class _StubNode: - """Simple node stub exposing the attributes needed by the state manager.""" - - def __init__(self, node_id: str) -> None: - self.id = node_id - self.state = NodeState.UNKNOWN - self.title = "Stub Node" - self.execution_type = NodeExecutionType.EXECUTABLE - self.error_strategy = None - self.retry_config = RetryConfig() - self.retry = False - - -def _build_event_handler(node_id: str) -> tuple[EventHandler, EventManager, GraphExecution]: - """Construct an EventHandler with in-memory dependencies for testing.""" - - node = _StubNode(node_id) - graph = Graph(nodes={node_id: node}, edges={}, in_edges={}, out_edges={}, root_node=node) - - variable_pool = VariablePool() - runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0) - graph_execution = GraphExecution(workflow_id="test-workflow") - - event_manager = EventManager() - state_manager = GraphStateManager(graph=graph, ready_queue=InMemoryReadyQueue()) - response_coordinator = ResponseStreamCoordinator(variable_pool=variable_pool, graph=graph) - - handler = EventHandler( - graph=graph, - graph_runtime_state=runtime_state, - graph_execution=graph_execution, - 
response_coordinator=response_coordinator, - event_collector=event_manager, - edge_processor=_StubEdgeProcessor(), - state_manager=state_manager, - error_handler=_StubErrorHandler(), - ) - - return handler, event_manager, graph_execution - - -def test_retry_does_not_emit_additional_start_event() -> None: - """Ensure retry attempts do not produce duplicate start events.""" - - node_id = "test-node" - handler, event_manager, graph_execution = _build_event_handler(node_id) - - execution_id = "exec-1" - node_type = BuiltinNodeTypes.CODE - start_time = naive_utc_now() - - start_event = NodeRunStartedEvent( - id=execution_id, - node_id=node_id, - node_type=node_type, - node_title="Stub Node", - start_at=start_time, - ) - handler.dispatch(start_event) - - retry_event = NodeRunRetryEvent( - id=execution_id, - node_id=node_id, - node_type=node_type, - node_title="Stub Node", - start_at=start_time, - error="boom", - retry_index=1, - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.FAILED, - error="boom", - error_type="TestError", - ), - ) - handler.dispatch(retry_event) - - # Simulate the node starting execution again after retry - second_start_event = NodeRunStartedEvent( - id=execution_id, - node_id=node_id, - node_type=node_type, - node_title="Stub Node", - start_at=start_time, - ) - handler.dispatch(second_start_event) - - collected_types = [type(event) for event in event_manager._events] # type: ignore[attr-defined] - - assert collected_types == [NodeRunStartedEvent, NodeRunRetryEvent] - - node_execution = graph_execution.get_or_create_node_execution(node_id) - assert node_execution.retry_count == 1 diff --git a/api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_manager.py b/api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_manager.py deleted file mode 100644 index 25494dc647..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/event_management/test_event_manager.py +++ /dev/null @@ -1,39 
+0,0 @@ -"""Tests for the EventManager.""" - -from __future__ import annotations - -import logging - -from dify_graph.graph_engine.event_management.event_manager import EventManager -from dify_graph.graph_engine.layers.base import GraphEngineLayer -from dify_graph.graph_events import GraphEngineEvent - - -class _FaultyLayer(GraphEngineLayer): - """Layer that raises from on_event to test error handling.""" - - def on_graph_start(self) -> None: # pragma: no cover - not used in tests - pass - - def on_event(self, event: GraphEngineEvent) -> None: - raise RuntimeError("boom") - - def on_graph_end(self, error: Exception | None) -> None: # pragma: no cover - not used in tests - pass - - -def test_event_manager_logs_layer_errors(caplog) -> None: - """Ensure errors raised by layers are logged when collecting events.""" - - event_manager = EventManager() - event_manager.set_layers([_FaultyLayer()]) - - with caplog.at_level(logging.ERROR): - event_manager.collect(GraphEngineEvent()) - - error_logs = [record for record in caplog.records if "Error in layer on_event" in record.getMessage()] - assert error_logs, "Expected layer errors to be logged" - - log_record = error_logs[0] - assert log_record.exc_info is not None - assert isinstance(log_record.exc_info[1], RuntimeError) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/graph_traversal/__init__.py b/api/tests/unit_tests/core/workflow/graph_engine/graph_traversal/__init__.py deleted file mode 100644 index cf8811dc2b..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/graph_traversal/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Tests for graph traversal components.""" diff --git a/api/tests/unit_tests/core/workflow/graph_engine/graph_traversal/test_skip_propagator.py b/api/tests/unit_tests/core/workflow/graph_engine/graph_traversal/test_skip_propagator.py deleted file mode 100644 index 73d59ea4e9..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/graph_traversal/test_skip_propagator.py 
+++ /dev/null @@ -1,307 +0,0 @@ -"""Unit tests for skip propagator.""" - -from unittest.mock import MagicMock, create_autospec - -from dify_graph.graph import Edge, Graph -from dify_graph.graph_engine.graph_state_manager import GraphStateManager -from dify_graph.graph_engine.graph_traversal.skip_propagator import SkipPropagator - - -class TestSkipPropagator: - """Test suite for SkipPropagator.""" - - def test_propagate_skip_from_edge_with_unknown_edges_stops_processing(self) -> None: - """When there are unknown incoming edges, propagation should stop.""" - # Arrange - mock_graph = create_autospec(Graph) - mock_state_manager = create_autospec(GraphStateManager) - - # Create a mock edge - mock_edge = MagicMock(spec=Edge) - mock_edge.id = "edge_1" - mock_edge.head = "node_2" - - # Setup graph edges dict - mock_graph.edges = {"edge_1": mock_edge} - - # Setup incoming edges - incoming_edges = [MagicMock(spec=Edge), MagicMock(spec=Edge)] - mock_graph.get_incoming_edges.return_value = incoming_edges - - # Setup state manager to return has_unknown=True - mock_state_manager.analyze_edge_states.return_value = { - "has_unknown": True, - "has_taken": False, - "all_skipped": False, - } - - propagator = SkipPropagator(mock_graph, mock_state_manager) - - # Act - propagator.propagate_skip_from_edge("edge_1") - - # Assert - mock_graph.get_incoming_edges.assert_called_once_with("node_2") - mock_state_manager.analyze_edge_states.assert_called_once_with(incoming_edges) - # Should not call any other state manager methods - mock_state_manager.enqueue_node.assert_not_called() - mock_state_manager.start_execution.assert_not_called() - mock_state_manager.mark_node_skipped.assert_not_called() - - def test_propagate_skip_from_edge_with_taken_edge_enqueues_node(self) -> None: - """When there is at least one taken edge, node should be enqueued.""" - # Arrange - mock_graph = create_autospec(Graph) - mock_state_manager = create_autospec(GraphStateManager) - - # Create a mock edge - mock_edge = 
MagicMock(spec=Edge) - mock_edge.id = "edge_1" - mock_edge.head = "node_2" - - mock_graph.edges = {"edge_1": mock_edge} - incoming_edges = [MagicMock(spec=Edge)] - mock_graph.get_incoming_edges.return_value = incoming_edges - - # Setup state manager to return has_taken=True - mock_state_manager.analyze_edge_states.return_value = { - "has_unknown": False, - "has_taken": True, - "all_skipped": False, - } - - propagator = SkipPropagator(mock_graph, mock_state_manager) - - # Act - propagator.propagate_skip_from_edge("edge_1") - - # Assert - mock_state_manager.enqueue_node.assert_called_once_with("node_2") - mock_state_manager.start_execution.assert_called_once_with("node_2") - mock_state_manager.mark_node_skipped.assert_not_called() - - def test_propagate_skip_from_edge_with_all_skipped_propagates_to_node(self) -> None: - """When all incoming edges are skipped, should propagate skip to node.""" - # Arrange - mock_graph = create_autospec(Graph) - mock_state_manager = create_autospec(GraphStateManager) - - # Create a mock edge - mock_edge = MagicMock(spec=Edge) - mock_edge.id = "edge_1" - mock_edge.head = "node_2" - - mock_graph.edges = {"edge_1": mock_edge} - incoming_edges = [MagicMock(spec=Edge)] - mock_graph.get_incoming_edges.return_value = incoming_edges - - # Setup state manager to return all_skipped=True - mock_state_manager.analyze_edge_states.return_value = { - "has_unknown": False, - "has_taken": False, - "all_skipped": True, - } - - propagator = SkipPropagator(mock_graph, mock_state_manager) - - # Act - propagator.propagate_skip_from_edge("edge_1") - - # Assert - mock_state_manager.mark_node_skipped.assert_called_once_with("node_2") - mock_state_manager.enqueue_node.assert_not_called() - mock_state_manager.start_execution.assert_not_called() - - def test_propagate_skip_to_node_marks_node_and_outgoing_edges_skipped(self) -> None: - """_propagate_skip_to_node should mark node and all outgoing edges as skipped.""" - # Arrange - mock_graph = 
create_autospec(Graph) - mock_state_manager = create_autospec(GraphStateManager) - - # Create outgoing edges - edge1 = MagicMock(spec=Edge) - edge1.id = "edge_2" - edge1.head = "node_downstream_1" # Set head for propagate_skip_from_edge - - edge2 = MagicMock(spec=Edge) - edge2.id = "edge_3" - edge2.head = "node_downstream_2" - - # Setup graph edges dict for propagate_skip_from_edge - mock_graph.edges = {"edge_2": edge1, "edge_3": edge2} - mock_graph.get_outgoing_edges.return_value = [edge1, edge2] - - # Setup get_incoming_edges to return empty list to stop recursion - mock_graph.get_incoming_edges.return_value = [] - - propagator = SkipPropagator(mock_graph, mock_state_manager) - - # Use mock to call private method - # Act - propagator._propagate_skip_to_node("node_1") - - # Assert - mock_state_manager.mark_node_skipped.assert_called_once_with("node_1") - mock_state_manager.mark_edge_skipped.assert_any_call("edge_2") - mock_state_manager.mark_edge_skipped.assert_any_call("edge_3") - assert mock_state_manager.mark_edge_skipped.call_count == 2 - # Should recursively propagate from each edge - # Since propagate_skip_from_edge is called, we need to verify it was called - # But we can't directly verify due to recursion. We'll trust the logic. 
- - def test_skip_branch_paths_marks_unselected_edges_and_propagates(self) -> None: - """skip_branch_paths should mark all unselected edges as skipped and propagate.""" - # Arrange - mock_graph = create_autospec(Graph) - mock_state_manager = create_autospec(GraphStateManager) - - # Create unselected edges - edge1 = MagicMock(spec=Edge) - edge1.id = "edge_1" - edge1.head = "node_downstream_1" - - edge2 = MagicMock(spec=Edge) - edge2.id = "edge_2" - edge2.head = "node_downstream_2" - - unselected_edges = [edge1, edge2] - - # Setup graph edges dict - mock_graph.edges = {"edge_1": edge1, "edge_2": edge2} - # Setup get_incoming_edges to return empty list to stop recursion - mock_graph.get_incoming_edges.return_value = [] - - propagator = SkipPropagator(mock_graph, mock_state_manager) - - # Act - propagator.skip_branch_paths(unselected_edges) - - # Assert - mock_state_manager.mark_edge_skipped.assert_any_call("edge_1") - mock_state_manager.mark_edge_skipped.assert_any_call("edge_2") - assert mock_state_manager.mark_edge_skipped.call_count == 2 - # propagate_skip_from_edge should be called for each edge - # We can't directly verify due to the mock, but the logic is covered - - def test_propagate_skip_from_edge_recursively_propagates_through_graph(self) -> None: - """Skip propagation should recursively propagate through the graph.""" - # Arrange - mock_graph = create_autospec(Graph) - mock_state_manager = create_autospec(GraphStateManager) - - # Create edge chain: edge_1 -> node_2 -> edge_3 -> node_4 - edge1 = MagicMock(spec=Edge) - edge1.id = "edge_1" - edge1.head = "node_2" - - edge3 = MagicMock(spec=Edge) - edge3.id = "edge_3" - edge3.head = "node_4" - - mock_graph.edges = {"edge_1": edge1, "edge_3": edge3} - - # Setup get_incoming_edges to return different values based on node - def get_incoming_edges_side_effect(node_id): - if node_id == "node_2": - return [edge1] - elif node_id == "node_4": - return [edge3] - return [] - - mock_graph.get_incoming_edges.side_effect = 
get_incoming_edges_side_effect - - # Setup get_outgoing_edges to return different values based on node - def get_outgoing_edges_side_effect(node_id): - if node_id == "node_2": - return [edge3] - elif node_id == "node_4": - return [] # No outgoing edges, stops recursion - return [] - - mock_graph.get_outgoing_edges.side_effect = get_outgoing_edges_side_effect - - # Setup state manager to return all_skipped for both nodes - mock_state_manager.analyze_edge_states.return_value = { - "has_unknown": False, - "has_taken": False, - "all_skipped": True, - } - - propagator = SkipPropagator(mock_graph, mock_state_manager) - - # Act - propagator.propagate_skip_from_edge("edge_1") - - # Assert - # Should mark node_2 as skipped - mock_state_manager.mark_node_skipped.assert_any_call("node_2") - # Should mark edge_3 as skipped - mock_state_manager.mark_edge_skipped.assert_any_call("edge_3") - # Should propagate to node_4 - mock_state_manager.mark_node_skipped.assert_any_call("node_4") - assert mock_state_manager.mark_node_skipped.call_count == 2 - - def test_propagate_skip_from_edge_with_mixed_edge_states_handles_correctly(self) -> None: - """Test with mixed edge states (some unknown, some taken, some skipped).""" - # Arrange - mock_graph = create_autospec(Graph) - mock_state_manager = create_autospec(GraphStateManager) - - mock_edge = MagicMock(spec=Edge) - mock_edge.id = "edge_1" - mock_edge.head = "node_2" - - mock_graph.edges = {"edge_1": mock_edge} - incoming_edges = [MagicMock(spec=Edge), MagicMock(spec=Edge), MagicMock(spec=Edge)] - mock_graph.get_incoming_edges.return_value = incoming_edges - - # Test 1: has_unknown=True, has_taken=False, all_skipped=False - mock_state_manager.analyze_edge_states.return_value = { - "has_unknown": True, - "has_taken": False, - "all_skipped": False, - } - - propagator = SkipPropagator(mock_graph, mock_state_manager) - - # Act - propagator.propagate_skip_from_edge("edge_1") - - # Assert - should stop processing - 
mock_state_manager.enqueue_node.assert_not_called() - mock_state_manager.mark_node_skipped.assert_not_called() - - # Reset mocks for next test - mock_state_manager.reset_mock() - mock_graph.reset_mock() - - # Test 2: has_unknown=False, has_taken=True, all_skipped=False - mock_state_manager.analyze_edge_states.return_value = { - "has_unknown": False, - "has_taken": True, - "all_skipped": False, - } - - # Act - propagator.propagate_skip_from_edge("edge_1") - - # Assert - should enqueue node - mock_state_manager.enqueue_node.assert_called_once_with("node_2") - mock_state_manager.start_execution.assert_called_once_with("node_2") - - # Reset mocks for next test - mock_state_manager.reset_mock() - mock_graph.reset_mock() - - # Test 3: has_unknown=False, has_taken=False, all_skipped=True - mock_state_manager.analyze_edge_states.return_value = { - "has_unknown": False, - "has_taken": False, - "all_skipped": True, - } - - # Act - propagator.propagate_skip_from_edge("edge_1") - - # Assert - should propagate skip - mock_state_manager.mark_node_skipped.assert_called_once_with("node_2") diff --git a/api/tests/unit_tests/core/workflow/graph_engine/human_input_test_utils.py b/api/tests/unit_tests/core/workflow/graph_engine/human_input_test_utils.py deleted file mode 100644 index fc8133f5e1..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/human_input_test_utils.py +++ /dev/null @@ -1,131 +0,0 @@ -"""Utilities for testing HumanInputNode without database dependencies.""" - -from __future__ import annotations - -from collections.abc import Mapping -from dataclasses import dataclass -from datetime import datetime, timedelta -from typing import Any - -from dify_graph.nodes.human_input.enums import HumanInputFormStatus -from dify_graph.repositories.human_input_form_repository import ( - FormCreateParams, - HumanInputFormEntity, - HumanInputFormRecipientEntity, - HumanInputFormRepository, -) -from libs.datetime_utils import naive_utc_now - - -class 
_InMemoryFormRecipient(HumanInputFormRecipientEntity): - """Minimal recipient entity required by the repository interface.""" - - def __init__(self, recipient_id: str, token: str) -> None: - self._id = recipient_id - self._token = token - - @property - def id(self) -> str: - return self._id - - @property - def token(self) -> str: - return self._token - - -@dataclass -class _InMemoryFormEntity(HumanInputFormEntity): - form_id: str - rendered: str - token: str | None = None - action_id: str | None = None - data: Mapping[str, Any] | None = None - is_submitted: bool = False - status_value: HumanInputFormStatus = HumanInputFormStatus.WAITING - expiration: datetime = naive_utc_now() - - @property - def id(self) -> str: - return self.form_id - - @property - def web_app_token(self) -> str | None: - return self.token - - @property - def recipients(self) -> list[HumanInputFormRecipientEntity]: - return [] - - @property - def rendered_content(self) -> str: - return self.rendered - - @property - def selected_action_id(self) -> str | None: - return self.action_id - - @property - def submitted_data(self) -> Mapping[str, Any] | None: - return self.data - - @property - def submitted(self) -> bool: - return self.is_submitted - - @property - def status(self) -> HumanInputFormStatus: - return self.status_value - - @property - def expiration_time(self) -> datetime: - return self.expiration - - -class InMemoryHumanInputFormRepository(HumanInputFormRepository): - """Pure in-memory repository used by workflow graph engine tests.""" - - def __init__(self) -> None: - self._form_counter = 0 - self.created_params: list[FormCreateParams] = [] - self.created_forms: list[_InMemoryFormEntity] = [] - self._forms_by_key: dict[tuple[str, str], _InMemoryFormEntity] = {} - - def create_form(self, params: FormCreateParams) -> HumanInputFormEntity: - self.created_params.append(params) - self._form_counter += 1 - form_id = f"form-{self._form_counter}" - token = f"console-{form_id}" if 
params.console_recipient_required else f"token-{form_id}" - entity = _InMemoryFormEntity( - form_id=form_id, - rendered=params.rendered_content, - token=token, - ) - self.created_forms.append(entity) - self._forms_by_key[(params.workflow_execution_id, params.node_id)] = entity - return entity - - def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None: - return self._forms_by_key.get((workflow_execution_id, node_id)) - - # Convenience helpers for tests ------------------------------------- - - def set_submission(self, *, action_id: str, form_data: Mapping[str, Any] | None = None) -> None: - """Simulate a human submission for the next repository lookup.""" - - if not self.created_forms: - raise AssertionError("no form has been created to attach submission data") - entity = self.created_forms[-1] - entity.action_id = action_id - entity.data = form_data or {} - entity.is_submitted = True - entity.status_value = HumanInputFormStatus.SUBMITTED - entity.expiration = naive_utc_now() + timedelta(days=1) - - def clear_submission(self) -> None: - if not self.created_forms: - return - for form in self.created_forms: - form.action_id = None - form.data = None - form.is_submitted = False - form.status_value = HumanInputFormStatus.WAITING diff --git a/api/tests/unit_tests/core/workflow/graph_engine/layers/conftest.py b/api/tests/unit_tests/core/workflow/graph_engine/layers/conftest.py index 9e7b3654b7..41627f5e0b 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/layers/conftest.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/layers/conftest.py @@ -5,13 +5,12 @@ Shared fixtures for ObservabilityLayer tests. 
from unittest.mock import MagicMock, patch import pytest +from graphon.enums import BuiltinNodeTypes from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace.export import SimpleSpanProcessor from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter from opentelemetry.trace import set_tracer_provider -from dify_graph.enums import BuiltinNodeTypes - @pytest.fixture def memory_span_exporter(): @@ -62,8 +61,9 @@ def mock_llm_node(): @pytest.fixture def mock_tool_node(): """Create a mock Tool Node with tool-specific attributes.""" + from graphon.nodes.tool.entities import ToolNodeData + from core.tools.entities.tool_entities import ToolProviderType - from dify_graph.nodes.tool.entities import ToolNodeData node = MagicMock() node.id = "test-tool-node-id" @@ -117,8 +117,8 @@ def mock_result_event(): """Create a mock result event with NodeRunResult.""" from datetime import datetime - from dify_graph.graph_events.node import NodeRunSucceededEvent - from dify_graph.node_events.base import NodeRunResult + from graphon.graph_events import NodeRunSucceededEvent + from graphon.node_events import NodeRunResult node_run_result = NodeRunResult( inputs={"query": "test query"}, diff --git a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_layer_initialization.py b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_layer_initialization.py deleted file mode 100644 index db32527849..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_layer_initialization.py +++ /dev/null @@ -1,57 +0,0 @@ -from __future__ import annotations - -import pytest - -from dify_graph.graph_engine import GraphEngine, GraphEngineConfig -from dify_graph.graph_engine.command_channels import InMemoryChannel -from dify_graph.graph_engine.layers.base import ( - GraphEngineLayer, - GraphEngineLayerNotInitializedError, -) -from dify_graph.graph_events import GraphEngineEvent - -from ..test_table_runner import WorkflowRunner 
- - -class LayerForTest(GraphEngineLayer): - def on_graph_start(self) -> None: - pass - - def on_event(self, event: GraphEngineEvent) -> None: - pass - - def on_graph_end(self, error: Exception | None) -> None: - pass - - -def test_layer_runtime_state_raises_when_uninitialized() -> None: - layer = LayerForTest() - - with pytest.raises(GraphEngineLayerNotInitializedError): - _ = layer.graph_runtime_state - - -def test_layer_runtime_state_available_after_engine_layer() -> None: - runner = WorkflowRunner() - fixture_data = runner.load_fixture("simple_passthrough_workflow") - graph, graph_runtime_state = runner.create_graph_from_fixture( - fixture_data, - inputs={"query": "test layer state"}, - ) - engine = GraphEngine( - workflow_id="test_workflow", - graph=graph, - graph_runtime_state=graph_runtime_state, - command_channel=InMemoryChannel(), - config=GraphEngineConfig(), - ) - - layer = LayerForTest() - engine.layer(layer) - - outputs = layer.graph_runtime_state.outputs - ready_queue_size = layer.graph_runtime_state.ready_queue_size - - assert outputs == {} - assert isinstance(ready_queue_size, int) - assert ready_queue_size >= 0 diff --git a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py index 2a36f712fd..99d131737e 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_llm_quota.py @@ -1,14 +1,28 @@ import threading from datetime import datetime +from types import SimpleNamespace from unittest.mock import MagicMock, patch +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus +from graphon.graph_engine.entities.commands import CommandType +from graphon.graph_events import NodeRunSucceededEvent +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.node_events import NodeRunResult + +from core.app.entities.app_invoke_entities 
import DifyRunContext, InvokeFrom, UserFrom from core.app.workflow.layers.llm_quota import LLMQuotaLayer from core.errors.error import QuotaExceededError -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from dify_graph.graph_engine.entities.commands import CommandType -from dify_graph.graph_events.node import NodeRunSucceededEvent -from dify_graph.model_runtime.entities.llm_entities import LLMUsage -from dify_graph.node_events import NodeRunResult +from core.model_manager import ModelInstance + + +def _build_dify_context() -> DifyRunContext: + return DifyRunContext( + tenant_id="tenant-id", + app_id="app-id", + user_id="user-id", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + ) def _build_succeeded_event() -> NodeRunSucceededEvent: @@ -25,6 +39,11 @@ def _build_succeeded_event() -> NodeRunSucceededEvent: ) +def _build_wrapped_model_instance() -> tuple[SimpleNamespace, ModelInstance]: + raw_model_instance = ModelInstance.__new__(ModelInstance) + return SimpleNamespace(_model_instance=raw_model_instance), raw_model_instance + + def test_deduct_quota_called_for_successful_llm_node() -> None: layer = LLMQuotaLayer() node = MagicMock() @@ -32,8 +51,8 @@ def test_deduct_quota_called_for_successful_llm_node() -> None: node.execution_id = "execution-id" node.node_type = BuiltinNodeTypes.LLM node.tenant_id = "tenant-id" - node.require_dify_context.return_value.tenant_id = "tenant-id" - node.model_instance = object() + node.require_run_context_value.return_value = _build_dify_context() + node.model_instance, raw_model_instance = _build_wrapped_model_instance() result_event = _build_succeeded_event() with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota", autospec=True) as mock_deduct: @@ -41,7 +60,7 @@ def test_deduct_quota_called_for_successful_llm_node() -> None: mock_deduct.assert_called_once_with( tenant_id="tenant-id", - model_instance=node.model_instance, + model_instance=raw_model_instance, 
usage=result_event.node_run_result.llm_usage, ) @@ -53,8 +72,8 @@ def test_deduct_quota_called_for_question_classifier_node() -> None: node.execution_id = "execution-id" node.node_type = BuiltinNodeTypes.QUESTION_CLASSIFIER node.tenant_id = "tenant-id" - node.require_dify_context.return_value.tenant_id = "tenant-id" - node.model_instance = object() + node.require_run_context_value.return_value = _build_dify_context() + node.model_instance, raw_model_instance = _build_wrapped_model_instance() result_event = _build_succeeded_event() with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota", autospec=True) as mock_deduct: @@ -62,7 +81,7 @@ def test_deduct_quota_called_for_question_classifier_node() -> None: mock_deduct.assert_called_once_with( tenant_id="tenant-id", - model_instance=node.model_instance, + model_instance=raw_model_instance, usage=result_event.node_run_result.llm_usage, ) @@ -74,7 +93,7 @@ def test_non_llm_node_is_ignored() -> None: node.execution_id = "execution-id" node.node_type = BuiltinNodeTypes.START node.tenant_id = "tenant-id" - node.require_dify_context.return_value.tenant_id = "tenant-id" + node.require_run_context_value.return_value = _build_dify_context() node._model_instance = object() result_event = _build_succeeded_event() @@ -91,7 +110,7 @@ def test_quota_error_is_handled_in_layer() -> None: node.execution_id = "execution-id" node.node_type = BuiltinNodeTypes.LLM node.tenant_id = "tenant-id" - node.require_dify_context.return_value.tenant_id = "tenant-id" + node.require_run_context_value.return_value = _build_dify_context() node.model_instance = object() result_event = _build_succeeded_event() @@ -113,8 +132,8 @@ def test_quota_deduction_exceeded_aborts_workflow_immediately() -> None: node.execution_id = "execution-id" node.node_type = BuiltinNodeTypes.LLM node.tenant_id = "tenant-id" - node.require_dify_context.return_value.tenant_id = "tenant-id" - node.model_instance = object() + node.require_run_context_value.return_value = 
_build_dify_context() + node.model_instance, _ = _build_wrapped_model_instance() node.graph_runtime_state = MagicMock() node.graph_runtime_state.stop_event = stop_event @@ -141,7 +160,7 @@ def test_quota_precheck_failure_aborts_workflow_immediately() -> None: node = MagicMock() node.id = "llm-node-id" node.node_type = BuiltinNodeTypes.LLM - node.model_instance = object() + node.model_instance, _ = _build_wrapped_model_instance() node.graph_runtime_state = MagicMock() node.graph_runtime_state.stop_event = stop_event @@ -167,7 +186,7 @@ def test_quota_precheck_passes_without_abort() -> None: node = MagicMock() node.id = "llm-node-id" node.node_type = BuiltinNodeTypes.LLM - node.model_instance = object() + node.model_instance, raw_model_instance = _build_wrapped_model_instance() node.graph_runtime_state = MagicMock() node.graph_runtime_state.stop_event = stop_event @@ -175,5 +194,5 @@ def test_quota_precheck_passes_without_abort() -> None: layer.on_node_run_start(node) assert not stop_event.is_set() - mock_check.assert_called_once_with(model_instance=node.model_instance) + mock_check.assert_called_once_with(model_instance=raw_model_instance) layer.command_channel.send_command.assert_not_called() diff --git a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_observability.py b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_observability.py index 478a2b592e..9cf72763ee 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/layers/test_observability.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/layers/test_observability.py @@ -13,10 +13,10 @@ Test coverage: from unittest.mock import patch import pytest +from graphon.enums import BuiltinNodeTypes from opentelemetry.trace import StatusCode from core.app.workflow.layers.observability import ObservabilityLayer -from dify_graph.enums import BuiltinNodeTypes class TestObservabilityLayerInitialization: @@ -144,7 +144,7 @@ class TestObservabilityLayerParserIntegration: self, 
tracer_provider_with_memory_exporter, memory_span_exporter, mock_llm_node, mock_result_event ): """Test that LLM parser is used for LLM nodes and extracts LLM-specific attributes.""" - from dify_graph.node_events.base import NodeRunResult + from graphon.node_events import NodeRunResult mock_result_event.node_run_result = NodeRunResult( inputs={}, @@ -182,7 +182,7 @@ class TestObservabilityLayerParserIntegration: self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_retrieval_node, mock_result_event ): """Test that retrieval parser is used for retrieval nodes and extracts retrieval-specific attributes.""" - from dify_graph.node_events.base import NodeRunResult + from graphon.node_events import NodeRunResult mock_result_event.node_run_result = NodeRunResult( inputs={"query": "test query"}, @@ -210,7 +210,7 @@ class TestObservabilityLayerParserIntegration: self, tracer_provider_with_memory_exporter, memory_span_exporter, mock_start_node, mock_result_event ): """Test that result_event parameter allows parsers to extract inputs and outputs.""" - from dify_graph.node_events.base import NodeRunResult + from graphon.node_events import NodeRunResult mock_result_event.node_run_result = NodeRunResult( inputs={"input_key": "input_value"}, diff --git a/api/tests/unit_tests/core/workflow/graph_engine/orchestration/test_dispatcher.py b/api/tests/unit_tests/core/workflow/graph_engine/orchestration/test_dispatcher.py deleted file mode 100644 index 548c10ce8d..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/orchestration/test_dispatcher.py +++ /dev/null @@ -1,189 +0,0 @@ -"""Tests for dispatcher command checking behavior.""" - -from __future__ import annotations - -import queue -from unittest import mock - -from dify_graph.entities.pause_reason import SchedulingPause -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from dify_graph.graph_engine.event_management.event_handlers import EventHandler -from 
dify_graph.graph_engine.orchestration.dispatcher import Dispatcher -from dify_graph.graph_engine.orchestration.execution_coordinator import ExecutionCoordinator -from dify_graph.graph_events import ( - GraphNodeEventBase, - NodeRunPauseRequestedEvent, - NodeRunStartedEvent, - NodeRunSucceededEvent, -) -from dify_graph.node_events import NodeRunResult -from libs.datetime_utils import naive_utc_now - - -def test_dispatcher_should_consume_remains_events_after_pause(): - event_queue = queue.Queue() - event_queue.put( - GraphNodeEventBase( - id="test", - node_id="test", - node_type=BuiltinNodeTypes.START, - ) - ) - event_handler = mock.Mock(spec=EventHandler) - execution_coordinator = mock.Mock(spec=ExecutionCoordinator) - execution_coordinator.paused.return_value = True - dispatcher = Dispatcher( - event_queue=event_queue, - event_handler=event_handler, - execution_coordinator=execution_coordinator, - ) - dispatcher._dispatcher_loop() - assert event_queue.empty() - - -class _StubExecutionCoordinator: - """Stub execution coordinator that tracks command checks.""" - - def __init__(self) -> None: - self.command_checks = 0 - self.scaling_checks = 0 - self.execution_complete = False - self.failed = False - self._paused = False - - def process_commands(self) -> None: - self.command_checks += 1 - - def check_scaling(self) -> None: - self.scaling_checks += 1 - - @property - def paused(self) -> bool: - return self._paused - - @property - def aborted(self) -> bool: - return False - - def mark_complete(self) -> None: - self.execution_complete = True - - def mark_failed(self, error: Exception) -> None: # pragma: no cover - defensive, not triggered in tests - self.failed = True - - -class _StubEventHandler: - """Minimal event handler that marks execution complete after handling an event.""" - - def __init__(self, coordinator: _StubExecutionCoordinator) -> None: - self._coordinator = coordinator - self.events = [] - - def dispatch(self, event) -> None: - self.events.append(event) - 
self._coordinator.mark_complete() - - -def _run_dispatcher_for_event(event) -> int: - """Run the dispatcher loop for a single event and return command check count.""" - event_queue: queue.Queue = queue.Queue() - event_queue.put(event) - - coordinator = _StubExecutionCoordinator() - event_handler = _StubEventHandler(coordinator) - - dispatcher = Dispatcher( - event_queue=event_queue, - event_handler=event_handler, - execution_coordinator=coordinator, - ) - - dispatcher._dispatcher_loop() - - return coordinator.command_checks - - -def _make_started_event() -> NodeRunStartedEvent: - return NodeRunStartedEvent( - id="start-event", - node_id="node-1", - node_type=BuiltinNodeTypes.CODE, - node_title="Test Node", - start_at=naive_utc_now(), - ) - - -def _make_succeeded_event() -> NodeRunSucceededEvent: - return NodeRunSucceededEvent( - id="success-event", - node_id="node-1", - node_type=BuiltinNodeTypes.CODE, - node_title="Test Node", - start_at=naive_utc_now(), - node_run_result=NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED), - ) - - -def test_dispatcher_checks_commands_during_idle_and_on_completion() -> None: - """Dispatcher polls commands when idle and after completion events.""" - started_checks = _run_dispatcher_for_event(_make_started_event()) - succeeded_checks = _run_dispatcher_for_event(_make_succeeded_event()) - - assert started_checks == 2 - assert succeeded_checks == 3 - - -class _PauseStubEventHandler: - """Minimal event handler that marks execution complete after handling an event.""" - - def __init__(self, coordinator: _StubExecutionCoordinator) -> None: - self._coordinator = coordinator - self.events = [] - - def dispatch(self, event) -> None: - self.events.append(event) - if isinstance(event, NodeRunPauseRequestedEvent): - self._coordinator.mark_complete() - - -def test_dispatcher_drain_event_queue(): - events = [ - NodeRunStartedEvent( - id="start-event", - node_id="node-1", - node_type=BuiltinNodeTypes.CODE, - node_title="Code", - 
start_at=naive_utc_now(), - ), - NodeRunPauseRequestedEvent( - id="pause-event", - node_id="node-1", - node_type=BuiltinNodeTypes.CODE, - reason=SchedulingPause(message="test pause"), - ), - NodeRunSucceededEvent( - id="success-event", - node_id="node-1", - node_type=BuiltinNodeTypes.CODE, - start_at=naive_utc_now(), - node_run_result=NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED), - ), - ] - - event_queue: queue.Queue = queue.Queue() - for e in events: - event_queue.put(e) - - coordinator = _StubExecutionCoordinator() - event_handler = _PauseStubEventHandler(coordinator) - - dispatcher = Dispatcher( - event_queue=event_queue, - event_handler=event_handler, - execution_coordinator=coordinator, - ) - - dispatcher._dispatcher_loop() - - # ensure all events are drained. - assert event_queue.empty() diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_answer_end_with_text.py b/api/tests/unit_tests/core/workflow/graph_engine/test_answer_end_with_text.py deleted file mode 100644 index 7af6b26d87..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_answer_end_with_text.py +++ /dev/null @@ -1,37 +0,0 @@ -from dify_graph.graph_events import ( - GraphRunStartedEvent, - GraphRunSucceededEvent, - NodeRunStartedEvent, - NodeRunStreamChunkEvent, - NodeRunSucceededEvent, -) - -from .test_table_runner import TableTestRunner, WorkflowTestCase - - -def test_answer_end_with_text(): - fixture_name = "answer_end_with_text" - case = WorkflowTestCase( - fixture_name, - query="Hello, AI!", - expected_outputs={"answer": "prefixHello, AI!suffix"}, - expected_event_sequence=[ - GraphRunStartedEvent, - # Start - NodeRunStartedEvent, - # The chunks are now emitted as the Answer node processes them - # since sys.query is a special selector that gets attributed to - # the active response node - NodeRunStreamChunkEvent, # prefix - NodeRunStreamChunkEvent, # sys.query - NodeRunStreamChunkEvent, # suffix - NodeRunSucceededEvent, - # Answer - 
NodeRunStartedEvent, - NodeRunSucceededEvent, - GraphRunSucceededEvent, - ], - ) - runner = TableTestRunner() - result = runner.run_test_case(case) - assert result.success, f"Test failed: {result.error}" diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_answer_order_workflow.py b/api/tests/unit_tests/core/workflow/graph_engine/test_answer_order_workflow.py deleted file mode 100644 index 6569439b56..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_answer_order_workflow.py +++ /dev/null @@ -1,28 +0,0 @@ -from .test_mock_config import MockConfigBuilder -from .test_table_runner import TableTestRunner, WorkflowTestCase - -LLM_NODE_ID = "1759052580454" - - -def test_answer_nodes_emit_in_order() -> None: - mock_config = ( - MockConfigBuilder() - .with_llm_response("unused default") - .with_node_output(LLM_NODE_ID, {"text": "mocked llm text"}) - .build() - ) - - expected_answer = "--- answer 1 ---\n\nfoo\n--- answer 2 ---\n\nmocked llm text\n" - - case = WorkflowTestCase( - fixture_path="test-answer-order", - query="", - expected_outputs={"answer": expected_answer}, - use_auto_mock=True, - mock_config=mock_config, - ) - - runner = TableTestRunner() - result = runner.run_test_case(case) - - assert result.success, result.error diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_array_iteration_formatting_workflow.py b/api/tests/unit_tests/core/workflow/graph_engine/test_array_iteration_formatting_workflow.py deleted file mode 100644 index 05ec565def..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_array_iteration_formatting_workflow.py +++ /dev/null @@ -1,24 +0,0 @@ -from .test_table_runner import TableTestRunner, WorkflowTestCase - - -def test_array_iteration_formatting_workflow(): - """ - Validate Iteration node processes [1,2,3] into formatted strings. 
- - Fixture description expects: - {"output": ["output: 1", "output: 2", "output: 3"]} - """ - runner = TableTestRunner() - - test_case = WorkflowTestCase( - fixture_path="array_iteration_formatting_workflow", - inputs={}, - expected_outputs={"output": ["output: 1", "output: 2", "output: 3"]}, - description="Iteration formats numbers into strings", - use_auto_mock=True, - ) - - result = runner.run_test_case(test_case) - - assert result.success, f"Iteration workflow failed: {result.error}" - assert result.actual_outputs == test_case.expected_outputs diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_auto_mock_system.py b/api/tests/unit_tests/core/workflow/graph_engine/test_auto_mock_system.py deleted file mode 100644 index fc0d22f739..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_auto_mock_system.py +++ /dev/null @@ -1,392 +0,0 @@ -""" -Tests for the auto-mock system. - -This module contains tests that validate the auto-mock functionality -for workflows containing nodes that require third-party services. 
-""" - -import pytest - -from dify_graph.enums import BuiltinNodeTypes -from tests.workflow_test_utils import build_test_graph_init_params - -from .test_mock_config import MockConfig, MockConfigBuilder, NodeMockConfig -from .test_table_runner import TableTestRunner, WorkflowTestCase - - -def test_simple_llm_workflow_with_auto_mock(): - """Test that a simple LLM workflow runs successfully with auto-mocking.""" - runner = TableTestRunner() - - # Create mock configuration - mock_config = MockConfigBuilder().with_llm_response("This is a test response from mocked LLM").build() - - test_case = WorkflowTestCase( - fixture_path="basic_llm_chat_workflow", - inputs={"query": "Hello, how are you?"}, - expected_outputs={"answer": "This is a test response from mocked LLM"}, - description="Simple LLM workflow with auto-mock", - use_auto_mock=True, - mock_config=mock_config, - ) - - result = runner.run_test_case(test_case) - - assert result.success, f"Workflow failed: {result.error}" - assert result.actual_outputs is not None - assert "answer" in result.actual_outputs - assert result.actual_outputs["answer"] == "This is a test response from mocked LLM" - - -def test_llm_workflow_with_custom_node_output(): - """Test LLM workflow with custom output for specific node.""" - runner = TableTestRunner() - - # Create mock configuration with custom output for specific node - mock_config = MockConfig() - mock_config.set_node_outputs( - "llm_node", - { - "text": "Custom response for this specific node", - "usage": { - "prompt_tokens": 20, - "completion_tokens": 10, - "total_tokens": 30, - }, - "finish_reason": "stop", - }, - ) - - test_case = WorkflowTestCase( - fixture_path="basic_llm_chat_workflow", - inputs={"query": "Test query"}, - expected_outputs={"answer": "Custom response for this specific node"}, - description="LLM workflow with custom node output", - use_auto_mock=True, - mock_config=mock_config, - ) - - result = runner.run_test_case(test_case) - - assert result.success, 
f"Workflow failed: {result.error}" - assert result.actual_outputs is not None - assert result.actual_outputs["answer"] == "Custom response for this specific node" - - -def test_http_tool_workflow_with_auto_mock(): - """Test workflow with HTTP request and tool nodes using auto-mock.""" - runner = TableTestRunner() - - # Create mock configuration - mock_config = MockConfig() - mock_config.set_node_outputs( - "http_node", - { - "status_code": 200, - "body": '{"key": "value", "number": 42}', - "headers": {"content-type": "application/json"}, - }, - ) - mock_config.set_node_outputs( - "tool_node", - { - "result": {"key": "value", "number": 42}, - }, - ) - - test_case = WorkflowTestCase( - fixture_path="http_request_with_json_tool_workflow", - inputs={"url": "https://api.example.com/data"}, - expected_outputs={ - "status_code": 200, - "parsed_data": {"key": "value", "number": 42}, - }, - description="HTTP and Tool workflow with auto-mock", - use_auto_mock=True, - mock_config=mock_config, - ) - - result = runner.run_test_case(test_case) - - assert result.success, f"Workflow failed: {result.error}" - assert result.actual_outputs is not None - assert result.actual_outputs["status_code"] == 200 - assert result.actual_outputs["parsed_data"] == {"key": "value", "number": 42} - - -def test_workflow_with_simulated_node_error(): - """Test that workflows handle simulated node errors correctly.""" - runner = TableTestRunner() - - # Create mock configuration with error - mock_config = MockConfig() - mock_config.set_node_error("llm_node", "Simulated LLM API error") - - test_case = WorkflowTestCase( - fixture_path="basic_llm_chat_workflow", - inputs={"query": "This should fail"}, - expected_outputs={}, # We expect failure, so no outputs - description="LLM workflow with simulated error", - use_auto_mock=True, - mock_config=mock_config, - ) - - result = runner.run_test_case(test_case) - - # The workflow should fail due to the simulated error - assert not result.success - assert 
result.error is not None - - -def test_workflow_with_mock_delays(): - """Test that mock delays work correctly.""" - runner = TableTestRunner() - - # Create mock configuration with delays - mock_config = MockConfig(simulate_delays=True) - node_config = NodeMockConfig( - node_id="llm_node", - outputs={"text": "Response after delay"}, - delay=0.1, # 100ms delay - ) - mock_config.set_node_config("llm_node", node_config) - - test_case = WorkflowTestCase( - fixture_path="basic_llm_chat_workflow", - inputs={"query": "Test with delay"}, - expected_outputs={"answer": "Response after delay"}, - description="LLM workflow with simulated delay", - use_auto_mock=True, - mock_config=mock_config, - ) - - result = runner.run_test_case(test_case) - - assert result.success, f"Workflow failed: {result.error}" - # Execution time should be at least the delay - assert result.execution_time >= 0.1 - - -def test_mock_config_builder(): - """Test the MockConfigBuilder fluent interface.""" - config = ( - MockConfigBuilder() - .with_llm_response("LLM response") - .with_agent_response("Agent response") - .with_tool_response({"tool": "output"}) - .with_retrieval_response("Retrieval content") - .with_http_response({"status_code": 201, "body": "created"}) - .with_node_output("node1", {"output": "value"}) - .with_node_error("node2", "error message") - .with_delays(True) - .build() - ) - - assert config.default_llm_response == "LLM response" - assert config.default_agent_response == "Agent response" - assert config.default_tool_response == {"tool": "output"} - assert config.default_retrieval_response == "Retrieval content" - assert config.default_http_response == {"status_code": 201, "body": "created"} - assert config.simulate_delays is True - - node1_config = config.get_node_config("node1") - assert node1_config is not None - assert node1_config.outputs == {"output": "value"} - - node2_config = config.get_node_config("node2") - assert node2_config is not None - assert node2_config.error == "error 
message" - - -def test_mock_factory_node_type_detection(): - """Test that MockNodeFactory correctly identifies nodes to mock.""" - from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom - from dify_graph.runtime import GraphRuntimeState, VariablePool - - from .test_mock_factory import MockNodeFactory - - graph_init_params = build_test_graph_init_params( - workflow_id="test", - graph_config={}, - tenant_id="test", - app_id="test", - user_id="test", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.SERVICE_API, - ) - graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(environment_variables=[], conversation_variables=[], user_inputs={}), - start_at=0, - total_tokens=0, - node_run_steps=0, - ) - factory = MockNodeFactory( - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - mock_config=None, - ) - - # Test that third-party service nodes are identified for mocking - assert factory.should_mock_node(BuiltinNodeTypes.LLM) - assert factory.should_mock_node(BuiltinNodeTypes.AGENT) - assert factory.should_mock_node(BuiltinNodeTypes.TOOL) - assert factory.should_mock_node(BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL) - assert factory.should_mock_node(BuiltinNodeTypes.HTTP_REQUEST) - assert factory.should_mock_node(BuiltinNodeTypes.PARAMETER_EXTRACTOR) - assert factory.should_mock_node(BuiltinNodeTypes.DOCUMENT_EXTRACTOR) - - # Test that CODE and TEMPLATE_TRANSFORM are mocked (they require SSRF proxy) - assert factory.should_mock_node(BuiltinNodeTypes.CODE) - assert factory.should_mock_node(BuiltinNodeTypes.TEMPLATE_TRANSFORM) - - # Test that non-service nodes are not mocked - assert not factory.should_mock_node(BuiltinNodeTypes.START) - assert not factory.should_mock_node(BuiltinNodeTypes.END) - assert not factory.should_mock_node(BuiltinNodeTypes.IF_ELSE) - assert not factory.should_mock_node(BuiltinNodeTypes.VARIABLE_AGGREGATOR) - - -def test_custom_mock_handler(): - """Test using a custom handler function for 
mock outputs.""" - runner = TableTestRunner() - - # Custom handler that modifies output based on input - def custom_llm_handler(node) -> dict: - # In a real scenario, we could access node.graph_runtime_state.variable_pool - # to get the actual inputs - return { - "text": "Custom handler response", - "usage": { - "prompt_tokens": 5, - "completion_tokens": 3, - "total_tokens": 8, - }, - "finish_reason": "stop", - } - - mock_config = MockConfig() - node_config = NodeMockConfig( - node_id="llm_node", - custom_handler=custom_llm_handler, - ) - mock_config.set_node_config("llm_node", node_config) - - test_case = WorkflowTestCase( - fixture_path="basic_llm_chat_workflow", - inputs={"query": "Test custom handler"}, - expected_outputs={"answer": "Custom handler response"}, - description="LLM workflow with custom handler", - use_auto_mock=True, - mock_config=mock_config, - ) - - result = runner.run_test_case(test_case) - - assert result.success, f"Workflow failed: {result.error}" - assert result.actual_outputs["answer"] == "Custom handler response" - - -def test_workflow_without_auto_mock(): - """Test that workflows work normally without auto-mock enabled.""" - runner = TableTestRunner() - - # This test uses the echo workflow which doesn't need external services - test_case = WorkflowTestCase( - fixture_path="simple_passthrough_workflow", - inputs={"query": "Test without mock"}, - expected_outputs={"query": "Test without mock"}, - description="Echo workflow without auto-mock", - use_auto_mock=False, # Auto-mock disabled - ) - - result = runner.run_test_case(test_case) - - assert result.success, f"Workflow failed: {result.error}" - assert result.actual_outputs["query"] == "Test without mock" - - -def test_register_custom_mock_node(): - """Test registering a custom mock implementation for a node type.""" - from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom - from dify_graph.nodes.template_transform import TemplateTransformNode - from dify_graph.runtime 
import GraphRuntimeState, VariablePool - - from .test_mock_factory import MockNodeFactory - - # Create a custom mock for TemplateTransformNode - class MockTemplateTransformNode(TemplateTransformNode): - def _run(self): - # Custom mock implementation - pass - - graph_init_params = build_test_graph_init_params( - workflow_id="test", - graph_config={}, - tenant_id="test", - app_id="test", - user_id="test", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.SERVICE_API, - ) - graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(environment_variables=[], conversation_variables=[], user_inputs={}), - start_at=0, - total_tokens=0, - node_run_steps=0, - ) - factory = MockNodeFactory( - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - mock_config=None, - ) - - # TEMPLATE_TRANSFORM is mocked by default (requires SSRF proxy) - assert factory.should_mock_node(BuiltinNodeTypes.TEMPLATE_TRANSFORM) - - # Unregister mock - factory.unregister_mock_node_type(BuiltinNodeTypes.TEMPLATE_TRANSFORM) - assert not factory.should_mock_node(BuiltinNodeTypes.TEMPLATE_TRANSFORM) - - # Re-register custom mock - factory.register_mock_node_type(BuiltinNodeTypes.TEMPLATE_TRANSFORM, MockTemplateTransformNode) - assert factory.should_mock_node(BuiltinNodeTypes.TEMPLATE_TRANSFORM) - - -def test_default_config_by_node_type(): - """Test setting default configurations by node type.""" - mock_config = MockConfig() - - # Set default config for all LLM nodes - mock_config.set_default_config( - BuiltinNodeTypes.LLM, - { - "default_response": "Default LLM response for all nodes", - "temperature": 0.7, - }, - ) - - # Set default config for all HTTP nodes - mock_config.set_default_config( - BuiltinNodeTypes.HTTP_REQUEST, - { - "default_status": 200, - "default_timeout": 30, - }, - ) - - llm_config = mock_config.get_default_config(BuiltinNodeTypes.LLM) - assert llm_config["default_response"] == "Default LLM response for all nodes" - assert 
llm_config["temperature"] == 0.7 - - http_config = mock_config.get_default_config(BuiltinNodeTypes.HTTP_REQUEST) - assert http_config["default_status"] == 200 - assert http_config["default_timeout"] == 30 - - # Non-configured node type should return empty dict - tool_config = mock_config.get_default_config(BuiltinNodeTypes.TOOL) - assert tool_config == {} - - -if __name__ == "__main__": - # Run all tests - pytest.main([__file__, "-v"]) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_basic_chatflow.py b/api/tests/unit_tests/core/workflow/graph_engine/test_basic_chatflow.py deleted file mode 100644 index 30acbdaf3d..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_basic_chatflow.py +++ /dev/null @@ -1,41 +0,0 @@ -from dify_graph.graph_events import ( - GraphRunStartedEvent, - GraphRunSucceededEvent, - NodeRunStartedEvent, - NodeRunStreamChunkEvent, - NodeRunSucceededEvent, -) - -from .test_mock_config import MockConfigBuilder -from .test_table_runner import TableTestRunner, WorkflowTestCase - - -def test_basic_chatflow(): - fixture_name = "basic_chatflow" - mock_config = MockConfigBuilder().with_llm_response("mocked llm response").build() - case = WorkflowTestCase( - fixture_path=fixture_name, - use_auto_mock=True, - mock_config=mock_config, - expected_outputs={"answer": "mocked llm response"}, - expected_event_sequence=[ - GraphRunStartedEvent, - # START - NodeRunStartedEvent, - NodeRunSucceededEvent, - # LLM - NodeRunStartedEvent, - ] - + [NodeRunStreamChunkEvent] * ("mocked llm response".count(" ") + 2) - + [ - NodeRunSucceededEvent, - # ANSWER - NodeRunStartedEvent, - NodeRunSucceededEvent, - GraphRunSucceededEvent, - ], - ) - - runner = TableTestRunner() - result = runner.run_test_case(case) - assert result.success, f"Test failed: {result.error}" diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py b/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py deleted file mode 
100644 index 765c4deba3..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_command_system.py +++ /dev/null @@ -1,267 +0,0 @@ -"""Test the command system for GraphEngine control.""" - -import time -from unittest.mock import MagicMock - -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY, GraphInitParams -from dify_graph.entities.pause_reason import SchedulingPause -from dify_graph.graph import Graph -from dify_graph.graph_engine import GraphEngine, GraphEngineConfig -from dify_graph.graph_engine.command_channels import InMemoryChannel -from dify_graph.graph_engine.entities.commands import ( - AbortCommand, - CommandType, - PauseCommand, - UpdateVariablesCommand, - VariableUpdate, -) -from dify_graph.graph_events import GraphRunAbortedEvent, GraphRunPausedEvent, GraphRunStartedEvent -from dify_graph.nodes.start.start_node import StartNode -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.variables import IntegerVariable, StringVariable - - -def test_abort_command(): - """Test that GraphEngine properly handles abort commands.""" - - # Create shared GraphRuntimeState - shared_runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter()) - - # Create a minimal mock graph - mock_graph = MagicMock(spec=Graph) - mock_graph.nodes = {} - mock_graph.edges = {} - mock_graph.root_node = MagicMock() - mock_graph.root_node.id = "start" - - # Create mock nodes with required attributes - using shared runtime state - start_node = StartNode( - id="start", - config={"id": "start", "data": {"title": "start", "variables": []}}, - graph_init_params=GraphInitParams( - workflow_id="test_workflow", - graph_config={}, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "test_tenant", - "app_id": "test_app", - "user_id": "test_user", - "user_from": UserFrom.ACCOUNT, - "invoke_from": InvokeFrom.DEBUGGER, - } - }, - 
call_depth=0, - ), - graph_runtime_state=shared_runtime_state, - ) - mock_graph.nodes["start"] = start_node - - # Mock graph methods - mock_graph.get_outgoing_edges = MagicMock(return_value=[]) - mock_graph.get_incoming_edges = MagicMock(return_value=[]) - - # Create command channel - command_channel = InMemoryChannel() - - # Create GraphEngine with same shared runtime state - engine = GraphEngine( - workflow_id="test_workflow", - graph=mock_graph, - graph_runtime_state=shared_runtime_state, # Use shared instance - command_channel=command_channel, - config=GraphEngineConfig(), - ) - - # Send abort command before starting - abort_command = AbortCommand(reason="Test abort") - command_channel.send_command(abort_command) - - # Run engine and collect events - events = list(engine.run()) - - # Verify we get start and abort events - assert any(isinstance(e, GraphRunStartedEvent) for e in events) - assert any(isinstance(e, GraphRunAbortedEvent) for e in events) - - # Find the abort event and check its reason - abort_events = [e for e in events if isinstance(e, GraphRunAbortedEvent)] - assert len(abort_events) == 1 - assert abort_events[0].reason is not None - assert "aborted: test abort" in abort_events[0].reason.lower() - - -def test_redis_channel_serialization(): - """Test that Redis channel properly serializes and deserializes commands.""" - import json - from unittest.mock import MagicMock - - # Mock redis client - mock_redis = MagicMock() - mock_pipeline = MagicMock() - mock_redis.pipeline.return_value.__enter__ = MagicMock(return_value=mock_pipeline) - mock_redis.pipeline.return_value.__exit__ = MagicMock(return_value=None) - - from dify_graph.graph_engine.command_channels.redis_channel import RedisChannel - - # Create channel with a specific key - channel = RedisChannel(mock_redis, channel_key="workflow:123:commands") - - # Test sending a command - abort_command = AbortCommand(reason="Test abort") - channel.send_command(abort_command) - - # Verify redis methods were 
called - mock_pipeline.rpush.assert_called_once() - mock_pipeline.expire.assert_called_once() - - # Verify the serialized data - call_args = mock_pipeline.rpush.call_args - key = call_args[0][0] - command_json = call_args[0][1] - - assert key == "workflow:123:commands" - - # Verify JSON structure - command_data = json.loads(command_json) - assert command_data["command_type"] == "abort" - assert command_data["reason"] == "Test abort" - - # Test pause command serialization - pause_command = PauseCommand(reason="User requested pause") - channel.send_command(pause_command) - - assert len(mock_pipeline.rpush.call_args_list) == 2 - second_call_args = mock_pipeline.rpush.call_args_list[1] - pause_command_json = second_call_args[0][1] - pause_command_data = json.loads(pause_command_json) - assert pause_command_data["command_type"] == CommandType.PAUSE.value - assert pause_command_data["reason"] == "User requested pause" - - -def test_pause_command(): - """Test that GraphEngine properly handles pause commands.""" - - shared_runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter()) - - mock_graph = MagicMock(spec=Graph) - mock_graph.nodes = {} - mock_graph.edges = {} - mock_graph.root_node = MagicMock() - mock_graph.root_node.id = "start" - - start_node = StartNode( - id="start", - config={"id": "start", "data": {"title": "start", "variables": []}}, - graph_init_params=GraphInitParams( - workflow_id="test_workflow", - graph_config={}, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "test_tenant", - "app_id": "test_app", - "user_id": "test_user", - "user_from": UserFrom.ACCOUNT, - "invoke_from": InvokeFrom.DEBUGGER, - } - }, - call_depth=0, - ), - graph_runtime_state=shared_runtime_state, - ) - mock_graph.nodes["start"] = start_node - - mock_graph.get_outgoing_edges = MagicMock(return_value=[]) - mock_graph.get_incoming_edges = MagicMock(return_value=[]) - - command_channel = InMemoryChannel() - - engine = GraphEngine( - 
workflow_id="test_workflow", - graph=mock_graph, - graph_runtime_state=shared_runtime_state, - command_channel=command_channel, - config=GraphEngineConfig(), - ) - - pause_command = PauseCommand(reason="User requested pause") - command_channel.send_command(pause_command) - - events = list(engine.run()) - - assert any(isinstance(e, GraphRunStartedEvent) for e in events) - pause_events = [e for e in events if isinstance(e, GraphRunPausedEvent)] - assert len(pause_events) == 1 - assert pause_events[0].reasons == [SchedulingPause(message="User requested pause")] - - graph_execution = engine.graph_runtime_state.graph_execution - assert graph_execution.pause_reasons == [SchedulingPause(message="User requested pause")] - - -def test_update_variables_command_updates_pool(): - """Test that GraphEngine updates variable pool via update variables command.""" - - shared_runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter()) - shared_runtime_state.variable_pool.add(("node1", "foo"), "old value") - - mock_graph = MagicMock(spec=Graph) - mock_graph.nodes = {} - mock_graph.edges = {} - mock_graph.root_node = MagicMock() - mock_graph.root_node.id = "start" - - start_node = StartNode( - id="start", - config={"id": "start", "data": {"title": "start", "variables": []}}, - graph_init_params=GraphInitParams( - workflow_id="test_workflow", - graph_config={}, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "test_tenant", - "app_id": "test_app", - "user_id": "test_user", - "user_from": UserFrom.ACCOUNT, - "invoke_from": InvokeFrom.DEBUGGER, - } - }, - call_depth=0, - ), - graph_runtime_state=shared_runtime_state, - ) - mock_graph.nodes["start"] = start_node - - mock_graph.get_outgoing_edges = MagicMock(return_value=[]) - mock_graph.get_incoming_edges = MagicMock(return_value=[]) - - command_channel = InMemoryChannel() - - engine = GraphEngine( - workflow_id="test_workflow", - graph=mock_graph, - graph_runtime_state=shared_runtime_state, - 
command_channel=command_channel, - config=GraphEngineConfig(), - ) - - update_command = UpdateVariablesCommand( - updates=[ - VariableUpdate( - value=StringVariable(name="foo", value="new value", selector=["node1", "foo"]), - ), - VariableUpdate( - value=IntegerVariable(name="bar", value=123, selector=["node2", "bar"]), - ), - ] - ) - command_channel.send_command(update_command) - - list(engine.run()) - - updated_existing = shared_runtime_state.variable_pool.get(["node1", "foo"]) - added_new = shared_runtime_state.variable_pool.get(["node2", "bar"]) - - assert updated_existing is not None - assert updated_existing.value == "new value" - assert added_new is not None - assert added_new.value == 123 diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_complex_branch_workflow.py b/api/tests/unit_tests/core/workflow/graph_engine/test_complex_branch_workflow.py deleted file mode 100644 index 3a9a0b18bc..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_complex_branch_workflow.py +++ /dev/null @@ -1,134 +0,0 @@ -""" -Test suite for complex branch workflow with parallel execution and conditional routing. - -This test suite validates the behavior of a workflow that: -1. Executes nodes in parallel (IF/ELSE and LLM branches) -2. Routes based on conditional logic (query containing 'hello') -3. 
Handles multiple answer nodes with different outputs -""" - -from dify_graph.graph_events import ( - GraphRunStartedEvent, - GraphRunSucceededEvent, - NodeRunStartedEvent, - NodeRunStreamChunkEvent, -) - -from .test_mock_config import MockConfigBuilder -from .test_table_runner import TableTestRunner, WorkflowTestCase - - -class TestComplexBranchWorkflow: - """Test suite for complex branch workflow with parallel execution.""" - - def setup_method(self): - """Set up test environment before each test method.""" - self.runner = TableTestRunner() - self.fixture_path = "test_complex_branch" - - def test_hello_branch_with_llm(self): - """ - Test when query contains 'hello' - should trigger true branch. - Both IF/ELSE and LLM should execute in parallel. - """ - mock_text_1 = "This is a mocked LLM response for hello world" - test_cases = [ - WorkflowTestCase( - fixture_path=self.fixture_path, - query="hello world", - expected_outputs={ - "answer": f"contains 'hello'{mock_text_1}", - }, - description="Basic hello case with parallel LLM execution", - use_auto_mock=True, - mock_config=(MockConfigBuilder().with_node_output("1755502777322", {"text": mock_text_1}).build()), - ), - WorkflowTestCase( - fixture_path=self.fixture_path, - query="say hello to everyone", - expected_outputs={ - "answer": "contains 'hello'Mocked response for greeting", - }, - description="Hello in middle of sentence", - use_auto_mock=True, - mock_config=( - MockConfigBuilder() - .with_node_output("1755502777322", {"text": "Mocked response for greeting"}) - .build() - ), - ), - ] - - suite_result = self.runner.run_table_tests(test_cases) - - for result in suite_result.results: - assert result.success, f"Test '{result.test_case.description}' failed: {result.error}" - assert result.actual_outputs - assert any(isinstance(event, GraphRunStartedEvent) for event in result.events) - assert any(isinstance(event, GraphRunSucceededEvent) for event in result.events) - - start_index = next( - idx for idx, event in 
enumerate(result.events) if isinstance(event, GraphRunStartedEvent) - ) - success_index = max( - idx for idx, event in enumerate(result.events) if isinstance(event, GraphRunSucceededEvent) - ) - assert start_index < success_index - - started_node_ids = {event.node_id for event in result.events if isinstance(event, NodeRunStartedEvent)} - assert {"1755502773326", "1755502777322"}.issubset(started_node_ids), ( - f"Branch or LLM nodes missing in events: {started_node_ids}" - ) - - assert any(isinstance(event, NodeRunStreamChunkEvent) for event in result.events), ( - "Expected streaming chunks from LLM execution" - ) - - llm_start_index = next( - idx - for idx, event in enumerate(result.events) - if isinstance(event, NodeRunStartedEvent) and event.node_id == "1755502777322" - ) - assert any( - idx > llm_start_index and isinstance(event, NodeRunStreamChunkEvent) - for idx, event in enumerate(result.events) - ), "Streaming chunks should follow LLM node start" - - def test_non_hello_branch_with_llm(self): - """ - Test when query doesn't contain 'hello' - should trigger false branch. - LLM output should be used as the final answer. 
- """ - test_cases = [ - WorkflowTestCase( - fixture_path=self.fixture_path, - query="goodbye world", - expected_outputs={ - "answer": "Mocked LLM response for goodbye", - }, - description="Goodbye case - false branch with LLM output", - use_auto_mock=True, - mock_config=( - MockConfigBuilder() - .with_node_output("1755502777322", {"text": "Mocked LLM response for goodbye"}) - .build() - ), - ), - WorkflowTestCase( - fixture_path=self.fixture_path, - query="test message", - expected_outputs={ - "answer": "Mocked response for test", - }, - description="Regular message - false branch", - use_auto_mock=True, - mock_config=( - MockConfigBuilder().with_node_output("1755502777322", {"text": "Mocked response for test"}).build() - ), - ), - ] - - suite_result = self.runner.run_table_tests(test_cases) - - for result in suite_result.results: - assert result.success, f"Test '{result.test_case.description}' failed: {result.error}" diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_conditional_streaming_vs_template_workflow.py b/api/tests/unit_tests/core/workflow/graph_engine/test_conditional_streaming_vs_template_workflow.py deleted file mode 100644 index 76bf179f33..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_conditional_streaming_vs_template_workflow.py +++ /dev/null @@ -1,220 +0,0 @@ -""" -Test for streaming output workflow behavior. 
- -This test validates that: -- When blocking == 1: No NodeRunStreamChunkEvent (flow through Template node) -- When blocking != 1: NodeRunStreamChunkEvent present (direct LLM to End output) -""" - -from dify_graph.enums import BuiltinNodeTypes -from dify_graph.graph_engine import GraphEngine, GraphEngineConfig -from dify_graph.graph_engine.command_channels import InMemoryChannel -from dify_graph.graph_events import ( - GraphRunSucceededEvent, - NodeRunStartedEvent, - NodeRunStreamChunkEvent, - NodeRunSucceededEvent, -) - -from .test_table_runner import TableTestRunner - - -def test_streaming_output_with_blocking_equals_one(): - """ - Test workflow when blocking == 1 (LLM → Template → End). - - Template node doesn't produce streaming output, so no NodeRunStreamChunkEvent should be present. - This test should FAIL according to requirements. - """ - runner = TableTestRunner() - - # Load the workflow configuration - fixture_data = runner.workflow_runner.load_fixture("conditional_streaming_vs_template_workflow") - - # Create graph from fixture with auto-mock enabled - graph, graph_runtime_state = runner.workflow_runner.create_graph_from_fixture( - fixture_data=fixture_data, - inputs={"query": "Hello, how are you?", "blocking": 1}, - use_mock_factory=True, - ) - - # Create and run the engine - engine = GraphEngine( - workflow_id="test_workflow", - graph=graph, - graph_runtime_state=graph_runtime_state, - command_channel=InMemoryChannel(), - config=GraphEngineConfig(), - ) - - # Execute the workflow - events = list(engine.run()) - - # Check for successful completion - success_events = [e for e in events if isinstance(e, GraphRunSucceededEvent)] - assert len(success_events) > 0, "Workflow should complete successfully" - - # Check for streaming events - stream_chunk_events = [e for e in events if isinstance(e, NodeRunStreamChunkEvent)] - stream_chunk_count = len(stream_chunk_events) - - # According to requirements, we expect exactly 3 streaming events from the End node - # 
1. User query - # 2. Newline - # 3. Template output (which contains the LLM response) - assert stream_chunk_count == 3, f"Expected 3 streaming events when blocking=1, but got {stream_chunk_count}" - - first_chunk, second_chunk, third_chunk = stream_chunk_events[0], stream_chunk_events[1], stream_chunk_events[2] - assert first_chunk.chunk == "Hello, how are you?", ( - f"Expected first chunk to be user input, but got {first_chunk.chunk}" - ) - assert second_chunk.chunk == "\n", f"Expected second chunk to be newline, but got {second_chunk.chunk}" - # Third chunk will be the template output with the mock LLM response - assert isinstance(third_chunk.chunk, str), f"Expected third chunk to be string, but got {type(third_chunk.chunk)}" - - # Find indices of first LLM success event and first stream chunk event - llm2_start_index = next( - ( - i - for i, e in enumerate(events) - if isinstance(e, NodeRunSucceededEvent) and e.node_type == BuiltinNodeTypes.LLM - ), - -1, - ) - first_chunk_index = next( - (i for i, e in enumerate(events) if isinstance(e, NodeRunStreamChunkEvent)), - -1, - ) - - assert first_chunk_index < llm2_start_index, ( - f"Expected first chunk before LLM2 start, but got {first_chunk_index} and {llm2_start_index}" - ) - - # Check that NodeRunStreamChunkEvent contains 'query' should has same id with Start NodeRunStartedEvent - start_node_id = graph.root_node.id - start_events = [e for e in events if isinstance(e, NodeRunStartedEvent) and e.node_id == start_node_id] - assert len(start_events) == 1, f"Expected 1 start event for node {start_node_id}, but got {len(start_events)}" - start_event = start_events[0] - query_chunk_events = [e for e in stream_chunk_events if e.chunk == "Hello, how are you?"] - assert all(e.id == start_event.id for e in query_chunk_events), "Expected all query chunk events to have same id" - - # Check all Template's NodeRunStreamChunkEvent should has same id with Template's NodeRunStartedEvent - start_events = [ - e for e in events if 
isinstance(e, NodeRunStartedEvent) and e.node_type == BuiltinNodeTypes.TEMPLATE_TRANSFORM - ] - template_chunk_events = [e for e in stream_chunk_events if e.node_type == BuiltinNodeTypes.TEMPLATE_TRANSFORM] - assert len(template_chunk_events) == 1, f"Expected 1 template chunk event, but got {len(template_chunk_events)}" - assert all(e.id in [se.id for se in start_events] for e in template_chunk_events), ( - "Expected all Template chunk events to have same id with Template's NodeRunStartedEvent" - ) - - # Check that NodeRunStreamChunkEvent contains '\n' is from the End node - end_events = [e for e in events if isinstance(e, NodeRunStartedEvent) and e.node_type == BuiltinNodeTypes.END] - assert len(end_events) == 1, f"Expected 1 end event, but got {len(end_events)}" - newline_chunk_events = [e for e in stream_chunk_events if e.chunk == "\n"] - assert len(newline_chunk_events) == 1, f"Expected 1 newline chunk event, but got {len(newline_chunk_events)}" - # The newline chunk should be from the End node (check node_id, not execution id) - assert all(e.node_id == end_events[0].node_id for e in newline_chunk_events), ( - "Expected all newline chunk events to be from End node" - ) - - -def test_streaming_output_with_blocking_not_equals_one(): - """ - Test workflow when blocking != 1 (LLM → End directly). - - End node should produce streaming output with NodeRunStreamChunkEvent. - This test should PASS according to requirements. 
- """ - runner = TableTestRunner() - - # Load the workflow configuration - fixture_data = runner.workflow_runner.load_fixture("conditional_streaming_vs_template_workflow") - - # Create graph from fixture with auto-mock enabled - graph, graph_runtime_state = runner.workflow_runner.create_graph_from_fixture( - fixture_data=fixture_data, - inputs={"query": "Hello, how are you?", "blocking": 2}, - use_mock_factory=True, - ) - - # Create and run the engine - engine = GraphEngine( - workflow_id="test_workflow", - graph=graph, - graph_runtime_state=graph_runtime_state, - command_channel=InMemoryChannel(), - config=GraphEngineConfig(), - ) - - # Execute the workflow - events = list(engine.run()) - - # Check for successful completion - success_events = [e for e in events if isinstance(e, GraphRunSucceededEvent)] - assert len(success_events) > 0, "Workflow should complete successfully" - - # Check for streaming events - expecting streaming events - stream_chunk_events = [e for e in events if isinstance(e, NodeRunStreamChunkEvent)] - stream_chunk_count = len(stream_chunk_events) - - # This assertion should PASS according to requirements - assert stream_chunk_count > 0, f"Expected streaming events when blocking!=1, but got {stream_chunk_count}" - - # We should have at least 2 chunks (query and newline) - assert stream_chunk_count >= 2, f"Expected at least 2 streaming events, but got {stream_chunk_count}" - - first_chunk, second_chunk = stream_chunk_events[0], stream_chunk_events[1] - assert first_chunk.chunk == "Hello, how are you?", ( - f"Expected first chunk to be user input, but got {first_chunk.chunk}" - ) - assert second_chunk.chunk == "\n", f"Expected second chunk to be newline, but got {second_chunk.chunk}" - - # Find indices of first LLM success event and first stream chunk event - llm2_start_index = next( - ( - i - for i, e in enumerate(events) - if isinstance(e, NodeRunSucceededEvent) and e.node_type == BuiltinNodeTypes.LLM - ), - -1, - ) - first_chunk_index = next( 
- (i for i, e in enumerate(events) if isinstance(e, NodeRunStreamChunkEvent)), - -1, - ) - - assert first_chunk_index < llm2_start_index, ( - f"Expected first chunk before LLM2 start, but got {first_chunk_index} and {llm2_start_index}" - ) - - # With auto-mock, the LLM will produce mock responses - just verify we have streaming chunks - # and they are strings - for chunk_event in stream_chunk_events[2:]: - assert isinstance(chunk_event.chunk, str), f"Expected chunk to be string, but got {type(chunk_event.chunk)}" - - # Check that NodeRunStreamChunkEvent contains 'query' should has same id with Start NodeRunStartedEvent - start_node_id = graph.root_node.id - start_events = [e for e in events if isinstance(e, NodeRunStartedEvent) and e.node_id == start_node_id] - assert len(start_events) == 1, f"Expected 1 start event for node {start_node_id}, but got {len(start_events)}" - start_event = start_events[0] - query_chunk_events = [e for e in stream_chunk_events if e.chunk == "Hello, how are you?"] - assert all(e.id == start_event.id for e in query_chunk_events), "Expected all query chunk events to have same id" - - # Check all LLM's NodeRunStreamChunkEvent should be from LLM nodes - start_events = [e for e in events if isinstance(e, NodeRunStartedEvent) and e.node_type == BuiltinNodeTypes.LLM] - llm_chunk_events = [e for e in stream_chunk_events if e.node_type == BuiltinNodeTypes.LLM] - llm_node_ids = {se.node_id for se in start_events} - assert all(e.node_id in llm_node_ids for e in llm_chunk_events), ( - "Expected all LLM chunk events to be from LLM nodes" - ) - - # Check that NodeRunStreamChunkEvent contains '\n' is from the End node - end_events = [e for e in events if isinstance(e, NodeRunStartedEvent) and e.node_type == BuiltinNodeTypes.END] - assert len(end_events) == 1, f"Expected 1 end event, but got {len(end_events)}" - newline_chunk_events = [e for e in stream_chunk_events if e.chunk == "\n"] - assert len(newline_chunk_events) == 1, f"Expected 1 newline chunk 
event, but got {len(newline_chunk_events)}" - # The newline chunk should be from the End node (check node_id, not execution id) - assert all(e.node_id == end_events[0].node_id for e in newline_chunk_events), ( - "Expected all newline chunk events to be from End node" - ) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_database_utils.py b/api/tests/unit_tests/core/workflow/graph_engine/test_database_utils.py deleted file mode 100644 index ae7dd48bb1..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_database_utils.py +++ /dev/null @@ -1,46 +0,0 @@ -""" -Utilities for detecting if database service is available for workflow tests. -""" - -import psycopg2 -import pytest - -from configs import dify_config - - -def is_database_available() -> bool: - """ - Check if the database service is available by attempting to connect to it. - - Returns: - True if database is available, False otherwise. - """ - try: - # Try to establish a database connection using a context manager - with psycopg2.connect( - host=dify_config.DB_HOST, - port=dify_config.DB_PORT, - database=dify_config.DB_DATABASE, - user=dify_config.DB_USERNAME, - password=dify_config.DB_PASSWORD, - connect_timeout=2, # 2 second timeout - ) as conn: - pass # Connection established and will be closed automatically - return True - except (psycopg2.OperationalError, psycopg2.Error): - return False - - -def skip_if_database_unavailable(): - """ - Pytest skip decorator that skips tests when database service is unavailable. - - Usage: - @skip_if_database_unavailable() - def test_my_workflow(): - ... 
- """ - return pytest.mark.skipif( - not is_database_available(), - reason="Database service is not available (connection refused or authentication failed)", - ) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_dispatcher_pause_drain.py b/api/tests/unit_tests/core/workflow/graph_engine/test_dispatcher_pause_drain.py deleted file mode 100644 index 778dad5952..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_dispatcher_pause_drain.py +++ /dev/null @@ -1,72 +0,0 @@ -import queue -from datetime import datetime - -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from dify_graph.graph_engine.orchestration.dispatcher import Dispatcher -from dify_graph.graph_events import NodeRunSucceededEvent -from dify_graph.node_events import NodeRunResult - - -class StubExecutionCoordinator: - def __init__(self, paused: bool) -> None: - self._paused = paused - self.mark_complete_called = False - self.failed_error: Exception | None = None - - @property - def aborted(self) -> bool: - return False - - @property - def paused(self) -> bool: - return self._paused - - @property - def execution_complete(self) -> bool: - return False - - def check_scaling(self) -> None: - return None - - def process_commands(self) -> None: - return None - - def mark_complete(self) -> None: - self.mark_complete_called = True - - def mark_failed(self, error: Exception) -> None: - self.failed_error = error - - -class StubEventHandler: - def __init__(self) -> None: - self.events: list[object] = [] - - def dispatch(self, event: object) -> None: - self.events.append(event) - - -def test_dispatcher_drains_events_when_paused() -> None: - event_queue: queue.Queue = queue.Queue() - event = NodeRunSucceededEvent( - id="exec-1", - node_id="node-1", - node_type=BuiltinNodeTypes.START, - start_at=datetime.utcnow(), - node_run_result=NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED), - ) - event_queue.put(event) - - handler = StubEventHandler() - 
coordinator = StubExecutionCoordinator(paused=True) - dispatcher = Dispatcher( - event_queue=event_queue, - event_handler=handler, - execution_coordinator=coordinator, - event_emitter=None, - ) - - dispatcher._dispatcher_loop() - - assert handler.events == [event] - assert coordinator.mark_complete_called is True diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_end_node_without_value_type.py b/api/tests/unit_tests/core/workflow/graph_engine/test_end_node_without_value_type.py deleted file mode 100644 index c87dc75b95..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_end_node_without_value_type.py +++ /dev/null @@ -1,60 +0,0 @@ -""" -Test case for end node without value_type field (backward compatibility). - -This test validates that end nodes work correctly even when the value_type -field is missing from the output configuration, ensuring backward compatibility -with older workflow definitions. -""" - -from dify_graph.graph_events import ( - GraphRunStartedEvent, - GraphRunSucceededEvent, - NodeRunStartedEvent, - NodeRunStreamChunkEvent, - NodeRunSucceededEvent, -) - -from .test_table_runner import TableTestRunner, WorkflowTestCase - - -def test_end_node_without_value_type_field(): - """ - Test that end node works without explicit value_type field. - - The fixture implements a simple workflow that: - 1. Takes a query input from start node - 2. Passes it directly to end node - 3. End node outputs the value without specifying value_type - 4. Should correctly infer the type and output the value - - This ensures backward compatibility with workflow definitions - created before value_type became a required field. 
- """ - fixture_name = "end_node_without_value_type_field_workflow" - - case = WorkflowTestCase( - fixture_path=fixture_name, - inputs={"query": "test query"}, - expected_outputs={"query": "test query"}, - expected_event_sequence=[ - # Graph start - GraphRunStartedEvent, - # Start node - NodeRunStartedEvent, - NodeRunStreamChunkEvent, # Start node streams the input value - NodeRunSucceededEvent, - # End node - NodeRunStartedEvent, - NodeRunSucceededEvent, - # Graph end - GraphRunSucceededEvent, - ], - description="End node without value_type field should work correctly", - ) - - runner = TableTestRunner() - result = runner.run_test_case(case) - assert result.success, f"Test failed: {result.error}" - assert result.actual_outputs == {"query": "test query"}, ( - f"Expected output to be {{'query': 'test query'}}, got {result.actual_outputs}" - ) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_execution_coordinator.py b/api/tests/unit_tests/core/workflow/graph_engine/test_execution_coordinator.py deleted file mode 100644 index 35406997ed..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_execution_coordinator.py +++ /dev/null @@ -1,62 +0,0 @@ -"""Unit tests for the execution coordinator orchestration logic.""" - -from unittest.mock import MagicMock - -import pytest - -from dify_graph.graph_engine.command_processing.command_processor import CommandProcessor -from dify_graph.graph_engine.domain.graph_execution import GraphExecution -from dify_graph.graph_engine.graph_state_manager import GraphStateManager -from dify_graph.graph_engine.orchestration.execution_coordinator import ExecutionCoordinator -from dify_graph.graph_engine.worker_management.worker_pool import WorkerPool - - -def _build_coordinator(graph_execution: GraphExecution) -> tuple[ExecutionCoordinator, MagicMock, MagicMock]: - command_processor = MagicMock(spec=CommandProcessor) - state_manager = MagicMock(spec=GraphStateManager) - worker_pool = MagicMock(spec=WorkerPool) 
- - coordinator = ExecutionCoordinator( - graph_execution=graph_execution, - state_manager=state_manager, - command_processor=command_processor, - worker_pool=worker_pool, - ) - return coordinator, state_manager, worker_pool - - -def test_handle_pause_stops_workers_and_clears_state() -> None: - """Paused execution should stop workers and clear executing state.""" - graph_execution = GraphExecution(workflow_id="workflow") - graph_execution.start() - graph_execution.pause("Awaiting human input") - - coordinator, state_manager, worker_pool = _build_coordinator(graph_execution) - - coordinator.handle_pause_if_needed() - - worker_pool.stop.assert_called_once_with() - state_manager.clear_executing.assert_called_once_with() - - -def test_handle_pause_noop_when_execution_running() -> None: - """Running execution should not trigger pause handling.""" - graph_execution = GraphExecution(workflow_id="workflow") - graph_execution.start() - - coordinator, state_manager, worker_pool = _build_coordinator(graph_execution) - - coordinator.handle_pause_if_needed() - - worker_pool.stop.assert_not_called() - state_manager.clear_executing.assert_not_called() - - -def test_has_executing_nodes_requires_pause() -> None: - graph_execution = GraphExecution(workflow_id="workflow") - graph_execution.start() - - coordinator, _, _ = _build_coordinator(graph_execution) - - with pytest.raises(AssertionError): - coordinator.has_executing_nodes() diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py deleted file mode 100644 index 4e13177d2b..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_engine.py +++ /dev/null @@ -1,770 +0,0 @@ -""" -Table-driven test framework for GraphEngine workflows. - -This file contains property-based tests and specific workflow tests. -The core test framework is in test_table_runner.py. 
-""" - -import time - -from hypothesis import HealthCheck, given, settings -from hypothesis import strategies as st - -from dify_graph.entities.base_node_data import DefaultValue, DefaultValueType -from dify_graph.enums import ErrorStrategy -from dify_graph.graph_engine import GraphEngine, GraphEngineConfig -from dify_graph.graph_engine.command_channels import InMemoryChannel -from dify_graph.graph_events import ( - GraphRunPartialSucceededEvent, - GraphRunStartedEvent, - GraphRunSucceededEvent, -) - -# Import the test framework from the new module -from .test_mock_config import MockConfigBuilder -from .test_table_runner import TableTestRunner, WorkflowRunner, WorkflowTestCase - - -# Property-based fuzzing tests for the start-end workflow -@given(query_input=st.text()) -@settings(max_examples=50, deadline=30000, suppress_health_check=[HealthCheck.too_slow]) -def test_echo_workflow_property_basic_strings(query_input): - """ - Property-based test: Echo workflow should return exactly what was input. - - This tests the fundamental property that for any string input, - the start-end workflow should echo it back unchanged. - """ - runner = TableTestRunner() - - test_case = WorkflowTestCase( - fixture_path="simple_passthrough_workflow", - inputs={"query": query_input}, - expected_outputs={"query": query_input}, - description=f"Fuzzing test with input: {repr(query_input)[:50]}...", - ) - - result = runner.run_test_case(test_case) - - # Property: The workflow should complete successfully - assert result.success, f"Workflow failed with input {repr(query_input)}: {result.error}" - - # Property: Output should equal input (echo behavior) - assert result.actual_outputs - assert result.actual_outputs == {"query": query_input}, ( - f"Echo property violated. 
Input: {repr(query_input)}, " - f"Expected: {repr(query_input)}, Got: {repr(result.actual_outputs.get('query'))}" - ) - - -@given(query_input=st.text(min_size=0, max_size=1000)) -@settings(max_examples=30, deadline=20000) -def test_echo_workflow_property_bounded_strings(query_input): - """ - Property-based test with size bounds to test edge cases more efficiently. - - Tests strings up to 1000 characters to balance thoroughness with performance. - """ - runner = TableTestRunner() - - test_case = WorkflowTestCase( - fixture_path="simple_passthrough_workflow", - inputs={"query": query_input}, - expected_outputs={"query": query_input}, - description=f"Bounded fuzzing test (len={len(query_input)})", - ) - - result = runner.run_test_case(test_case) - - assert result.success, f"Workflow failed with bounded input: {result.error}" - assert result.actual_outputs == {"query": query_input} - - -@given( - query_input=st.one_of( - st.text(alphabet=st.characters(whitelist_categories=["Lu", "Ll", "Nd", "Po"])), # Letters, digits, punctuation - st.text(alphabet="🎉🌟💫⭐🔥💯🚀🎯"), # Emojis - st.text(alphabet="αβγδεζηθικλμνξοπρστυφχψω"), # Greek letters - st.text(alphabet="中文测试한국어日本語العربية"), # International characters - st.just(""), # Empty string - st.just(" " * 100), # Whitespace only - st.just("\n\t\r\f\v"), # Special whitespace chars - st.just('{"json": "like", "data": [1, 2, 3]}'), # JSON-like string - st.just("SELECT * FROM users; DROP TABLE users;--"), # SQL injection attempt - st.just(""), # XSS attempt - st.just("../../etc/passwd"), # Path traversal attempt - ) -) -@settings(max_examples=40, deadline=25000) -def test_echo_workflow_property_diverse_inputs(query_input): - """ - Property-based test with diverse input types including edge cases and security payloads. 
- - Tests various categories of potentially problematic inputs: - - Unicode characters from different languages - - Emojis and special symbols - - Whitespace variations - - Malicious payloads (SQL injection, XSS, path traversal) - - JSON-like structures - """ - runner = TableTestRunner() - - test_case = WorkflowTestCase( - fixture_path="simple_passthrough_workflow", - inputs={"query": query_input}, - expected_outputs={"query": query_input}, - description=f"Diverse input fuzzing: {type(query_input).__name__}", - ) - - result = runner.run_test_case(test_case) - - # Property: System should handle all inputs gracefully (no crashes) - assert result.success, f"Workflow failed with diverse input {repr(query_input)}: {result.error}" - - # Property: Echo behavior must be preserved regardless of input type - assert result.actual_outputs == {"query": query_input} - - -@given(query_input=st.text(min_size=1000, max_size=5000)) -@settings(max_examples=10, deadline=60000) -def test_echo_workflow_property_large_inputs(query_input): - """ - Property-based test for large inputs to test memory and performance boundaries. - - Tests the system's ability to handle larger payloads efficiently. 
- """ - runner = TableTestRunner() - - test_case = WorkflowTestCase( - fixture_path="simple_passthrough_workflow", - inputs={"query": query_input}, - expected_outputs={"query": query_input}, - description=f"Large input test (size: {len(query_input)} chars)", - timeout=45.0, # Longer timeout for large inputs - ) - - start_time = time.perf_counter() - result = runner.run_test_case(test_case) - execution_time = time.perf_counter() - start_time - - # Property: Large inputs should still work - assert result.success, f"Large input workflow failed: {result.error}" - - # Property: Echo behavior preserved for large inputs - assert result.actual_outputs == {"query": query_input} - - # Property: Performance should be reasonable even for large inputs - assert execution_time < 30.0, f"Large input took too long: {execution_time:.2f}s" - - -def test_echo_workflow_robustness_smoke_test(): - """ - Smoke test to ensure the basic workflow functionality works before fuzzing. - - This test uses a simple, known-good input to verify the test infrastructure - is working correctly before running the fuzzing tests. - """ - runner = TableTestRunner() - - test_case = WorkflowTestCase( - fixture_path="simple_passthrough_workflow", - inputs={"query": "smoke test"}, - expected_outputs={"query": "smoke test"}, - description="Smoke test for basic functionality", - ) - - result = runner.run_test_case(test_case) - - assert result.success, f"Smoke test failed: {result.error}" - assert result.actual_outputs == {"query": "smoke test"} - assert result.execution_time > 0 - - -def test_if_else_workflow_true_branch(): - """ - Test if-else workflow when input contains 'hello' (true branch). - - Should output {"true": input_query} when query contains "hello". 
- """ - runner = TableTestRunner() - - test_cases = [ - WorkflowTestCase( - fixture_path="conditional_hello_branching_workflow", - inputs={"query": "hello world"}, - expected_outputs={"true": "hello world"}, - description="Basic hello case", - ), - WorkflowTestCase( - fixture_path="conditional_hello_branching_workflow", - inputs={"query": "say hello to everyone"}, - expected_outputs={"true": "say hello to everyone"}, - description="Hello in middle of sentence", - ), - WorkflowTestCase( - fixture_path="conditional_hello_branching_workflow", - inputs={"query": "hello"}, - expected_outputs={"true": "hello"}, - description="Just hello", - ), - WorkflowTestCase( - fixture_path="conditional_hello_branching_workflow", - inputs={"query": "hellohello"}, - expected_outputs={"true": "hellohello"}, - description="Multiple hello occurrences", - ), - ] - - suite_result = runner.run_table_tests(test_cases) - - for result in suite_result.results: - assert result.success, f"Test case '{result.test_case.description}' failed: {result.error}" - # Check that outputs contain ONLY the expected key (true branch) - assert result.actual_outputs == result.test_case.expected_outputs, ( - f"Expected only 'true' key in outputs for {result.test_case.description}. " - f"Expected: {result.test_case.expected_outputs}, Got: {result.actual_outputs}" - ) - - -def test_if_else_workflow_false_branch(): - """ - Test if-else workflow when input does not contain 'hello' (false branch). - - Should output {"false": input_query} when query does not contain "hello". 
- """ - runner = TableTestRunner() - - test_cases = [ - WorkflowTestCase( - fixture_path="conditional_hello_branching_workflow", - inputs={"query": "goodbye world"}, - expected_outputs={"false": "goodbye world"}, - description="Basic goodbye case", - ), - WorkflowTestCase( - fixture_path="conditional_hello_branching_workflow", - inputs={"query": "hi there"}, - expected_outputs={"false": "hi there"}, - description="Simple greeting without hello", - ), - WorkflowTestCase( - fixture_path="conditional_hello_branching_workflow", - inputs={"query": ""}, - expected_outputs={"false": ""}, - description="Empty string", - ), - WorkflowTestCase( - fixture_path="conditional_hello_branching_workflow", - inputs={"query": "test message"}, - expected_outputs={"false": "test message"}, - description="Regular message", - ), - ] - - suite_result = runner.run_table_tests(test_cases) - - for result in suite_result.results: - assert result.success, f"Test case '{result.test_case.description}' failed: {result.error}" - # Check that outputs contain ONLY the expected key (false branch) - assert result.actual_outputs == result.test_case.expected_outputs, ( - f"Expected only 'false' key in outputs for {result.test_case.description}. " - f"Expected: {result.test_case.expected_outputs}, Got: {result.actual_outputs}" - ) - - -def test_if_else_workflow_edge_cases(): - """ - Test if-else workflow edge cases and case sensitivity. - - Tests various edge cases including case sensitivity, similar words, etc. 
- """ - runner = TableTestRunner() - - test_cases = [ - WorkflowTestCase( - fixture_path="conditional_hello_branching_workflow", - inputs={"query": "Hello world"}, - expected_outputs={"false": "Hello world"}, - description="Capitalized Hello (case sensitive test)", - ), - WorkflowTestCase( - fixture_path="conditional_hello_branching_workflow", - inputs={"query": "HELLO"}, - expected_outputs={"false": "HELLO"}, - description="All caps HELLO (case sensitive test)", - ), - WorkflowTestCase( - fixture_path="conditional_hello_branching_workflow", - inputs={"query": "helllo"}, - expected_outputs={"false": "helllo"}, - description="Typo: helllo (with extra l)", - ), - WorkflowTestCase( - fixture_path="conditional_hello_branching_workflow", - inputs={"query": "helo"}, - expected_outputs={"false": "helo"}, - description="Typo: helo (missing l)", - ), - WorkflowTestCase( - fixture_path="conditional_hello_branching_workflow", - inputs={"query": "hello123"}, - expected_outputs={"true": "hello123"}, - description="Hello with numbers", - ), - WorkflowTestCase( - fixture_path="conditional_hello_branching_workflow", - inputs={"query": "hello!@#"}, - expected_outputs={"true": "hello!@#"}, - description="Hello with special characters", - ), - WorkflowTestCase( - fixture_path="conditional_hello_branching_workflow", - inputs={"query": " hello "}, - expected_outputs={"true": " hello "}, - description="Hello with surrounding spaces", - ), - ] - - suite_result = runner.run_table_tests(test_cases) - - for result in suite_result.results: - assert result.success, f"Test case '{result.test_case.description}' failed: {result.error}" - # Check that outputs contain ONLY the expected key - assert result.actual_outputs == result.test_case.expected_outputs, ( - f"Expected exact match for {result.test_case.description}. 
" - f"Expected: {result.test_case.expected_outputs}, Got: {result.actual_outputs}" - ) - - -@given(query_input=st.text()) -@settings(max_examples=50, deadline=30000, suppress_health_check=[HealthCheck.too_slow]) -def test_if_else_workflow_property_basic_strings(query_input): - """ - Property-based test: If-else workflow should output correct branch based on 'hello' content. - - This tests the fundamental property that for any string input: - - If input contains "hello", output should be {"true": input} - - If input doesn't contain "hello", output should be {"false": input} - """ - runner = TableTestRunner() - - # Determine expected output based on whether input contains "hello" - contains_hello = "hello" in query_input - expected_key = "true" if contains_hello else "false" - expected_outputs = {expected_key: query_input} - - test_case = WorkflowTestCase( - fixture_path="conditional_hello_branching_workflow", - inputs={"query": query_input}, - expected_outputs=expected_outputs, - description=f"Property test with input: {repr(query_input)[:50]}...", - ) - - result = runner.run_test_case(test_case) - - # Property: The workflow should complete successfully - assert result.success, f"Workflow failed with input {repr(query_input)}: {result.error}" - - # Property: Output should contain ONLY the expected key with correct value - assert result.actual_outputs == expected_outputs, ( - f"If-else property violated. Input: {repr(query_input)}, " - f"Expected: {expected_outputs}, Got: {result.actual_outputs}" - ) - - -@given(query_input=st.text(min_size=0, max_size=1000)) -@settings(max_examples=30, deadline=20000) -def test_if_else_workflow_property_bounded_strings(query_input): - """ - Property-based test with size bounds for if-else workflow. - - Tests strings up to 1000 characters to balance thoroughness with performance. 
- """ - runner = TableTestRunner() - - contains_hello = "hello" in query_input - expected_key = "true" if contains_hello else "false" - expected_outputs = {expected_key: query_input} - - test_case = WorkflowTestCase( - fixture_path="conditional_hello_branching_workflow", - inputs={"query": query_input}, - expected_outputs=expected_outputs, - description=f"Bounded if-else test (len={len(query_input)}, contains_hello={contains_hello})", - ) - - result = runner.run_test_case(test_case) - - assert result.success, f"Workflow failed with bounded input: {result.error}" - assert result.actual_outputs == expected_outputs - - -@given( - query_input=st.one_of( - st.text(alphabet=st.characters(whitelist_categories=["Lu", "Ll", "Nd", "Po"])), # Letters, digits, punctuation - st.text(alphabet="hello"), # Strings that definitely contain hello - st.text(alphabet="xyz"), # Strings that definitely don't contain hello - st.just("hello world"), # Known true case - st.just("goodbye world"), # Known false case - st.just(""), # Empty string - st.just("Hello"), # Case sensitivity test - st.just("HELLO"), # Case sensitivity test - st.just("hello" * 10), # Multiple hello occurrences - st.just("say hello to everyone"), # Hello in middle - st.text(alphabet="🎉🌟💫⭐🔥💯🚀🎯"), # Emojis - st.text(alphabet="中文测试한국어日本語العربية"), # International characters - ) -) -@settings(max_examples=40, deadline=25000) -def test_if_else_workflow_property_diverse_inputs(query_input): - """ - Property-based test with diverse input types for if-else workflow. 
- - Tests various categories including: - - Known true/false cases - - Case sensitivity scenarios - - Unicode characters from different languages - - Emojis and special symbols - - Multiple hello occurrences - """ - runner = TableTestRunner() - - contains_hello = "hello" in query_input - expected_key = "true" if contains_hello else "false" - expected_outputs = {expected_key: query_input} - - test_case = WorkflowTestCase( - fixture_path="conditional_hello_branching_workflow", - inputs={"query": query_input}, - expected_outputs=expected_outputs, - description=f"Diverse if-else test: {type(query_input).__name__} (contains_hello={contains_hello})", - ) - - result = runner.run_test_case(test_case) - - # Property: System should handle all inputs gracefully (no crashes) - assert result.success, f"Workflow failed with diverse input {repr(query_input)}: {result.error}" - - # Property: Correct branch logic must be preserved regardless of input type - assert result.actual_outputs == expected_outputs, ( - f"Branch logic violated. 
Input: {repr(query_input)}, " - f"Contains 'hello': {contains_hello}, Expected: {expected_outputs}, Got: {result.actual_outputs}" - ) - - -# Tests for the Layer system -def test_layer_system_basic(): - """Test basic layer functionality with DebugLoggingLayer.""" - from dify_graph.graph_engine.layers import DebugLoggingLayer - - runner = WorkflowRunner() - - # Load a simple echo workflow - fixture_data = runner.load_fixture("simple_passthrough_workflow") - graph, graph_runtime_state = runner.create_graph_from_fixture(fixture_data, inputs={"query": "test layer system"}) - - # Create engine with layer - engine = GraphEngine( - workflow_id="test_workflow", - graph=graph, - graph_runtime_state=graph_runtime_state, - command_channel=InMemoryChannel(), - config=GraphEngineConfig(), - ) - - # Add debug logging layer - debug_layer = DebugLoggingLayer(level="DEBUG", include_inputs=True, include_outputs=True) - engine.layer(debug_layer) - - # Run workflow - events = list(engine.run()) - - # Verify events were generated - assert len(events) > 0 - assert isinstance(events[0], GraphRunStartedEvent) - assert isinstance(events[-1], GraphRunSucceededEvent) - - # Verify layer received context - assert debug_layer.graph_runtime_state is not None - assert debug_layer.command_channel is not None - - # Verify layer tracked execution stats - assert debug_layer.node_count > 0 - assert debug_layer.success_count > 0 - - -def test_layer_chaining(): - """Test chaining multiple layers.""" - from dify_graph.graph_engine.layers import DebugLoggingLayer, GraphEngineLayer - - # Create a custom test layer - class TestLayer(GraphEngineLayer): - def __init__(self): - super().__init__() - self.events_received = [] - self.graph_started = False - self.graph_ended = False - - def on_graph_start(self): - self.graph_started = True - - def on_event(self, event): - self.events_received.append(event.__class__.__name__) - - def on_graph_end(self, error): - self.graph_ended = True - - runner = WorkflowRunner() 
- - # Load workflow - fixture_data = runner.load_fixture("simple_passthrough_workflow") - graph, graph_runtime_state = runner.create_graph_from_fixture(fixture_data, inputs={"query": "test chaining"}) - - # Create engine - engine = GraphEngine( - workflow_id="test_workflow", - graph=graph, - graph_runtime_state=graph_runtime_state, - command_channel=InMemoryChannel(), - config=GraphEngineConfig(), - ) - - # Chain multiple layers - test_layer = TestLayer() - debug_layer = DebugLoggingLayer(level="INFO") - - engine.layer(test_layer).layer(debug_layer) - - # Run workflow - events = list(engine.run()) - - # Verify both layers received events - assert test_layer.graph_started - assert test_layer.graph_ended - assert len(test_layer.events_received) > 0 - - # Verify debug layer also worked - assert debug_layer.node_count > 0 - - -def test_layer_error_handling(): - """Test that layer errors don't crash the engine.""" - from dify_graph.graph_engine.layers import GraphEngineLayer - - # Create a layer that throws errors - class FaultyLayer(GraphEngineLayer): - def on_graph_start(self): - raise RuntimeError("Intentional error in on_graph_start") - - def on_event(self, event): - raise RuntimeError("Intentional error in on_event") - - def on_graph_end(self, error): - raise RuntimeError("Intentional error in on_graph_end") - - runner = WorkflowRunner() - - # Load workflow - fixture_data = runner.load_fixture("simple_passthrough_workflow") - graph, graph_runtime_state = runner.create_graph_from_fixture(fixture_data, inputs={"query": "test error handling"}) - - # Create engine with faulty layer - engine = GraphEngine( - workflow_id="test_workflow", - graph=graph, - graph_runtime_state=graph_runtime_state, - command_channel=InMemoryChannel(), - config=GraphEngineConfig(), - ) - - # Add faulty layer - engine.layer(FaultyLayer()) - - # Run workflow - should not crash despite layer errors - events = list(engine.run()) - - # Verify workflow still completed successfully - assert 
len(events) > 0 - assert isinstance(events[-1], GraphRunSucceededEvent) - assert events[-1].outputs == {"query": "test error handling"} - - -def test_event_sequence_validation(): - """Test the new event sequence validation feature.""" - from dify_graph.graph_events import NodeRunStartedEvent, NodeRunStreamChunkEvent, NodeRunSucceededEvent - - runner = TableTestRunner() - - # Test 1: Successful event sequence validation - test_case_success = WorkflowTestCase( - fixture_path="simple_passthrough_workflow", - inputs={"query": "test event sequence"}, - expected_outputs={"query": "test event sequence"}, - expected_event_sequence=[ - GraphRunStartedEvent, - NodeRunStartedEvent, # Start node begins - NodeRunStreamChunkEvent, # Start node streaming - NodeRunSucceededEvent, # Start node completes - NodeRunStartedEvent, # End node begins - NodeRunSucceededEvent, # End node completes - GraphRunSucceededEvent, # Graph completes - ], - description="Test with correct event sequence", - ) - - result = runner.run_test_case(test_case_success) - assert result.success, f"Test should pass with correct event sequence. 
Error: {result.event_mismatch_details}" - assert result.event_sequence_match is True - assert result.event_mismatch_details is None - - # Test 2: Failed event sequence validation - wrong order - test_case_wrong_order = WorkflowTestCase( - fixture_path="simple_passthrough_workflow", - inputs={"query": "test wrong order"}, - expected_outputs={"query": "test wrong order"}, - expected_event_sequence=[ - GraphRunStartedEvent, - NodeRunSucceededEvent, # Wrong: expecting success before start - NodeRunStreamChunkEvent, - NodeRunStartedEvent, - NodeRunStartedEvent, - NodeRunSucceededEvent, - GraphRunSucceededEvent, - ], - description="Test with incorrect event order", - ) - - result = runner.run_test_case(test_case_wrong_order) - assert not result.success, "Test should fail with incorrect event sequence" - assert result.event_sequence_match is False - assert result.event_mismatch_details is not None - assert "Event mismatch at position" in result.event_mismatch_details - - # Test 3: Failed event sequence validation - wrong count - test_case_wrong_count = WorkflowTestCase( - fixture_path="simple_passthrough_workflow", - inputs={"query": "test wrong count"}, - expected_outputs={"query": "test wrong count"}, - expected_event_sequence=[ - GraphRunStartedEvent, - NodeRunStartedEvent, - NodeRunSucceededEvent, - # Missing the second node's events - GraphRunSucceededEvent, - ], - description="Test with incorrect event count", - ) - - result = runner.run_test_case(test_case_wrong_count) - assert not result.success, "Test should fail with incorrect event count" - assert result.event_sequence_match is False - assert result.event_mismatch_details is not None - assert "Event count mismatch" in result.event_mismatch_details - - # Test 4: No event sequence validation (backward compatibility) - test_case_no_validation = WorkflowTestCase( - fixture_path="simple_passthrough_workflow", - inputs={"query": "test no validation"}, - expected_outputs={"query": "test no validation"}, - # No 
expected_event_sequence provided - description="Test without event sequence validation", - ) - - result = runner.run_test_case(test_case_no_validation) - assert result.success, "Test should pass when no event sequence is provided" - assert result.event_sequence_match is None - assert result.event_mismatch_details is None - - -def test_event_sequence_validation_with_table_tests(): - """Test event sequence validation with table-driven tests.""" - from dify_graph.graph_events import NodeRunStartedEvent, NodeRunStreamChunkEvent, NodeRunSucceededEvent - - runner = TableTestRunner() - - test_cases = [ - WorkflowTestCase( - fixture_path="simple_passthrough_workflow", - inputs={"query": "test1"}, - expected_outputs={"query": "test1"}, - expected_event_sequence=[ - GraphRunStartedEvent, - NodeRunStartedEvent, - NodeRunStreamChunkEvent, - NodeRunSucceededEvent, - NodeRunStartedEvent, - NodeRunSucceededEvent, - GraphRunSucceededEvent, - ], - description="Table test 1: Valid sequence", - ), - WorkflowTestCase( - fixture_path="simple_passthrough_workflow", - inputs={"query": "test2"}, - expected_outputs={"query": "test2"}, - # No event sequence validation for this test - description="Table test 2: No sequence validation", - ), - WorkflowTestCase( - fixture_path="simple_passthrough_workflow", - inputs={"query": "test3"}, - expected_outputs={"query": "test3"}, - expected_event_sequence=[ - GraphRunStartedEvent, - NodeRunStartedEvent, - NodeRunStreamChunkEvent, - NodeRunSucceededEvent, - NodeRunStartedEvent, - NodeRunSucceededEvent, - GraphRunSucceededEvent, - ], - description="Table test 3: Valid sequence", - ), - ] - - suite_result = runner.run_table_tests(test_cases) - - # Check all tests passed - for i, result in enumerate(suite_result.results): - if i == 1: # Test 2 has no event sequence validation - assert result.event_sequence_match is None - else: - assert result.event_sequence_match is True - assert result.success, f"Test {i + 1} failed: {result.event_mismatch_details or 
result.error}" - - -def test_graph_run_emits_partial_success_when_node_failure_recovered(): - runner = TableTestRunner() - - fixture_data = runner.workflow_runner.load_fixture("basic_chatflow") - mock_config = MockConfigBuilder().with_node_error("llm", "mock llm failure").build() - - graph, graph_runtime_state = runner.workflow_runner.create_graph_from_fixture( - fixture_data=fixture_data, - query="hello", - use_mock_factory=True, - mock_config=mock_config, - ) - - llm_node = graph.nodes["llm"] - base_node_data = llm_node.node_data - base_node_data.error_strategy = ErrorStrategy.DEFAULT_VALUE - base_node_data.default_value = [DefaultValue(key="text", value="fallback response", type=DefaultValueType.STRING)] - - engine = GraphEngine( - workflow_id="test_workflow", - graph=graph, - graph_runtime_state=graph_runtime_state, - command_channel=InMemoryChannel(), - config=GraphEngineConfig(), - ) - - events = list(engine.run()) - - assert isinstance(events[-1], GraphRunPartialSucceededEvent) - - partial_event = next(event for event in events if isinstance(event, GraphRunPartialSucceededEvent)) - assert partial_event.exceptions_count == 1 - assert partial_event.outputs.get("answer") == "fallback response" - - assert not any(isinstance(event, GraphRunSucceededEvent) for event in events) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_execution_serialization.py b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_execution_serialization.py deleted file mode 100644 index 255784b77d..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_execution_serialization.py +++ /dev/null @@ -1,196 +0,0 @@ -"""Unit tests for GraphExecution serialization helpers.""" - -from __future__ import annotations - -import json -from collections import deque -from unittest.mock import MagicMock - -from dify_graph.enums import BuiltinNodeTypes, NodeExecutionType, NodeState -from dify_graph.graph_engine.domain import GraphExecution -from 
dify_graph.graph_engine.response_coordinator import ResponseStreamCoordinator -from dify_graph.graph_engine.response_coordinator.path import Path -from dify_graph.graph_engine.response_coordinator.session import ResponseSession -from dify_graph.graph_events import NodeRunStreamChunkEvent -from dify_graph.nodes.base.template import Template, TextSegment, VariableSegment - - -class CustomGraphExecutionError(Exception): - """Custom exception used to verify error serialization.""" - - -def test_graph_execution_serialization_round_trip() -> None: - """GraphExecution serialization restores full aggregate state.""" - # Arrange - execution = GraphExecution(workflow_id="wf-1") - execution.start() - node_a = execution.get_or_create_node_execution("node-a") - node_a.mark_started(execution_id="exec-1") - node_a.increment_retry() - node_a.mark_failed("boom") - node_b = execution.get_or_create_node_execution("node-b") - node_b.mark_skipped() - execution.fail(CustomGraphExecutionError("serialization failure")) - - # Act - serialized = execution.dumps() - payload = json.loads(serialized) - restored = GraphExecution(workflow_id="wf-1") - restored.loads(serialized) - - # Assert - assert payload["type"] == "GraphExecution" - assert payload["version"] == "1.0" - assert restored.workflow_id == "wf-1" - assert restored.started is True - assert restored.completed is True - assert restored.aborted is False - assert isinstance(restored.error, CustomGraphExecutionError) - assert str(restored.error) == "serialization failure" - assert set(restored.node_executions) == {"node-a", "node-b"} - restored_node_a = restored.node_executions["node-a"] - assert restored_node_a.state is NodeState.TAKEN - assert restored_node_a.retry_count == 1 - assert restored_node_a.execution_id == "exec-1" - assert restored_node_a.error == "boom" - restored_node_b = restored.node_executions["node-b"] - assert restored_node_b.state is NodeState.SKIPPED - assert restored_node_b.retry_count == 0 - assert 
restored_node_b.execution_id is None - assert restored_node_b.error is None - - -def test_graph_execution_loads_replaces_existing_state() -> None: - """loads replaces existing runtime data with serialized snapshot.""" - # Arrange - source = GraphExecution(workflow_id="wf-2") - source.start() - source_node = source.get_or_create_node_execution("node-source") - source_node.mark_taken() - serialized = source.dumps() - - target = GraphExecution(workflow_id="wf-2") - target.start() - target.abort("pre-existing abort") - temp_node = target.get_or_create_node_execution("node-temp") - temp_node.increment_retry() - temp_node.mark_failed("temp error") - - # Act - target.loads(serialized) - - # Assert - assert target.aborted is False - assert target.error is None - assert target.started is True - assert target.completed is False - assert set(target.node_executions) == {"node-source"} - restored_node = target.node_executions["node-source"] - assert restored_node.state is NodeState.TAKEN - assert restored_node.retry_count == 0 - assert restored_node.execution_id is None - assert restored_node.error is None - - -def test_response_stream_coordinator_serialization_round_trip(monkeypatch) -> None: - """ResponseStreamCoordinator serialization restores coordinator internals.""" - - template_main = Template(segments=[TextSegment(text="Hi "), VariableSegment(selector=["node-source", "text"])]) - template_secondary = Template(segments=[TextSegment(text="secondary")]) - - class DummyNode: - def __init__(self, node_id: str, template: Template, execution_type: NodeExecutionType) -> None: - self.id = node_id - self.node_type = ( - BuiltinNodeTypes.ANSWER if execution_type == NodeExecutionType.RESPONSE else BuiltinNodeTypes.LLM - ) - self.execution_type = execution_type - self.state = NodeState.UNKNOWN - self.title = node_id - self.template = template - - def blocks_variable_output(self, *_args) -> bool: - return False - - response_node1 = DummyNode("response-1", template_main, 
NodeExecutionType.RESPONSE) - response_node2 = DummyNode("response-2", template_main, NodeExecutionType.RESPONSE) - response_node3 = DummyNode("response-3", template_main, NodeExecutionType.RESPONSE) - source_node = DummyNode("node-source", template_secondary, NodeExecutionType.EXECUTABLE) - - class DummyGraph: - def __init__(self) -> None: - self.nodes = { - response_node1.id: response_node1, - response_node2.id: response_node2, - response_node3.id: response_node3, - source_node.id: source_node, - } - self.edges: dict[str, object] = {} - self.root_node = response_node1 - - def get_outgoing_edges(self, _node_id: str): # pragma: no cover - not exercised - return [] - - def get_incoming_edges(self, _node_id: str): # pragma: no cover - not exercised - return [] - - graph = DummyGraph() - - def fake_from_node(cls, node: DummyNode) -> ResponseSession: - return ResponseSession(node_id=node.id, template=node.template) - - monkeypatch.setattr(ResponseSession, "from_node", classmethod(fake_from_node)) - - coordinator = ResponseStreamCoordinator(variable_pool=MagicMock(), graph=graph) # type: ignore[arg-type] - coordinator._response_nodes = {"response-1", "response-2", "response-3"} - coordinator._paths_maps = { - "response-1": [Path(edges=["edge-1"])], - "response-2": [Path(edges=[])], - "response-3": [Path(edges=["edge-2", "edge-3"])], - } - - active_session = ResponseSession(node_id="response-1", template=response_node1.template) - active_session.index = 1 - coordinator._active_session = active_session - waiting_session = ResponseSession(node_id="response-2", template=response_node2.template) - coordinator._waiting_sessions = deque([waiting_session]) - pending_session = ResponseSession(node_id="response-3", template=response_node3.template) - pending_session.index = 2 - coordinator._response_sessions = {"response-3": pending_session} - - coordinator._node_execution_ids = {"response-1": "exec-1"} - event = NodeRunStreamChunkEvent( - id="exec-1", - node_id="response-1", - 
node_type=BuiltinNodeTypes.ANSWER, - selector=["node-source", "text"], - chunk="chunk-1", - is_final=False, - ) - coordinator._stream_buffers = {("node-source", "text"): [event]} - coordinator._stream_positions = {("node-source", "text"): 1} - coordinator._closed_streams = {("node-source", "text")} - - serialized = coordinator.dumps() - - restored = ResponseStreamCoordinator(variable_pool=MagicMock(), graph=graph) # type: ignore[arg-type] - monkeypatch.setattr(ResponseSession, "from_node", classmethod(fake_from_node)) - restored.loads(serialized) - - assert restored._response_nodes == {"response-1", "response-2", "response-3"} - assert restored._paths_maps["response-1"][0].edges == ["edge-1"] - assert restored._active_session is not None - assert restored._active_session.node_id == "response-1" - assert restored._active_session.index == 1 - waiting_restored = list(restored._waiting_sessions) - assert len(waiting_restored) == 1 - assert waiting_restored[0].node_id == "response-2" - assert waiting_restored[0].index == 0 - assert set(restored._response_sessions) == {"response-3"} - assert restored._response_sessions["response-3"].index == 2 - assert restored._node_execution_ids == {"response-1": "exec-1"} - assert ("node-source", "text") in restored._stream_buffers - restored_event = restored._stream_buffers[("node-source", "text")][0] - assert restored_event.chunk == "chunk-1" - assert restored._stream_positions[("node-source", "text")] == 1 - assert ("node-source", "text") in restored._closed_streams diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_state_snapshot.py b/api/tests/unit_tests/core/workflow/graph_engine/test_graph_state_snapshot.py deleted file mode 100644 index d54f0be190..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_graph_state_snapshot.py +++ /dev/null @@ -1,190 +0,0 @@ -import time -from collections.abc import Mapping - -from dify_graph.entities import GraphInitParams -from dify_graph.enums import 
NodeState -from dify_graph.graph import Graph -from dify_graph.graph_engine.graph_state_manager import GraphStateManager -from dify_graph.graph_engine.ready_queue import InMemoryReadyQueue -from dify_graph.model_runtime.entities.llm_entities import LLMMode -from dify_graph.model_runtime.entities.message_entities import PromptMessageRole -from dify_graph.nodes.end.end_node import EndNode -from dify_graph.nodes.end.entities import EndNodeData -from dify_graph.nodes.llm.entities import ( - ContextConfig, - LLMNodeChatModelMessage, - LLMNodeData, - ModelConfig, - VisionConfig, -) -from dify_graph.nodes.start.entities import StartNodeData -from dify_graph.nodes.start.start_node import StartNode -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable -from tests.workflow_test_utils import build_test_graph_init_params - -from .test_mock_config import MockConfig -from .test_mock_nodes import MockLLMNode - - -def _build_runtime_state() -> GraphRuntimeState: - variable_pool = VariablePool( - system_variables=SystemVariable( - user_id="user", - app_id="app", - workflow_id="workflow", - workflow_execution_id="exec-1", - ), - user_inputs={}, - conversation_variables=[], - ) - return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - - -def _build_llm_node( - *, - node_id: str, - runtime_state: GraphRuntimeState, - graph_init_params: GraphInitParams, - mock_config: MockConfig, -) -> MockLLMNode: - llm_data = LLMNodeData( - title=f"LLM {node_id}", - model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}), - prompt_template=[ - LLMNodeChatModelMessage( - text=f"Prompt {node_id}", - role=PromptMessageRole.USER, - edition_type="basic", - ) - ], - context=ContextConfig(enabled=False, variable_selector=None), - vision=VisionConfig(enabled=False), - reasoning_format="tagged", - ) - llm_config = {"id": node_id, "data": llm_data.model_dump()} - return 
MockLLMNode( - id=llm_config["id"], - config=llm_config, - graph_init_params=graph_init_params, - graph_runtime_state=runtime_state, - mock_config=mock_config, - ) - - -def _build_graph(runtime_state: GraphRuntimeState) -> Graph: - graph_config: dict[str, object] = {"nodes": [], "edges": []} - graph_init_params = build_test_graph_init_params( - workflow_id="workflow", - graph_config=graph_config, - tenant_id="tenant", - app_id="app", - user_id="user", - user_from="account", - invoke_from="debugger", - call_depth=0, - ) - - start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()} - start_node = StartNode( - id=start_config["id"], - config=start_config, - graph_init_params=graph_init_params, - graph_runtime_state=runtime_state, - ) - - mock_config = MockConfig() - llm_a = _build_llm_node( - node_id="llm_a", - runtime_state=runtime_state, - graph_init_params=graph_init_params, - mock_config=mock_config, - ) - llm_b = _build_llm_node( - node_id="llm_b", - runtime_state=runtime_state, - graph_init_params=graph_init_params, - mock_config=mock_config, - ) - - end_data = EndNodeData(title="End", outputs=[], desc=None) - end_config = {"id": "end", "data": end_data.model_dump()} - end_node = EndNode( - id=end_config["id"], - config=end_config, - graph_init_params=graph_init_params, - graph_runtime_state=runtime_state, - ) - - builder = ( - Graph.new() - .add_root(start_node) - .add_node(llm_a, from_node_id="start") - .add_node(llm_b, from_node_id="start") - .add_node(end_node, from_node_id="llm_a") - ) - return builder.connect(tail="llm_b", head="end").build() - - -def _edge_state_map(graph: Graph) -> Mapping[tuple[str, str, str], NodeState]: - return {(edge.tail, edge.head, edge.source_handle): edge.state for edge in graph.edges.values()} - - -def test_runtime_state_snapshot_restores_graph_states() -> None: - runtime_state = _build_runtime_state() - graph = _build_graph(runtime_state) - runtime_state.attach_graph(graph) - - 
graph.nodes["llm_a"].state = NodeState.TAKEN - graph.nodes["llm_b"].state = NodeState.SKIPPED - - for edge in graph.edges.values(): - if edge.tail == "start" and edge.head == "llm_a": - edge.state = NodeState.TAKEN - elif edge.tail == "start" and edge.head == "llm_b": - edge.state = NodeState.SKIPPED - elif edge.head == "end" and edge.tail == "llm_a": - edge.state = NodeState.TAKEN - elif edge.head == "end" and edge.tail == "llm_b": - edge.state = NodeState.SKIPPED - - snapshot = runtime_state.dumps() - - resumed_state = GraphRuntimeState.from_snapshot(snapshot) - resumed_graph = _build_graph(resumed_state) - resumed_state.attach_graph(resumed_graph) - - assert resumed_graph.nodes["llm_a"].state == NodeState.TAKEN - assert resumed_graph.nodes["llm_b"].state == NodeState.SKIPPED - assert _edge_state_map(resumed_graph) == _edge_state_map(graph) - - -def test_join_readiness_uses_restored_edge_states() -> None: - runtime_state = _build_runtime_state() - graph = _build_graph(runtime_state) - runtime_state.attach_graph(graph) - - ready_queue = InMemoryReadyQueue() - state_manager = GraphStateManager(graph, ready_queue) - - for edge in graph.get_incoming_edges("end"): - if edge.tail == "llm_a": - edge.state = NodeState.TAKEN - if edge.tail == "llm_b": - edge.state = NodeState.UNKNOWN - - assert state_manager.is_node_ready("end") is False - - for edge in graph.get_incoming_edges("end"): - if edge.tail == "llm_b": - edge.state = NodeState.TAKEN - - assert state_manager.is_node_ready("end") is True - - snapshot = runtime_state.dumps() - resumed_state = GraphRuntimeState.from_snapshot(snapshot) - resumed_graph = _build_graph(resumed_state) - resumed_state.attach_graph(resumed_graph) - - resumed_state_manager = GraphStateManager(resumed_graph, InMemoryReadyQueue()) - assert resumed_state_manager.is_node_ready("end") is True diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py 
b/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py deleted file mode 100644 index 538f53c603..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_multi_branch.py +++ /dev/null @@ -1,387 +0,0 @@ -import datetime -import time -from collections.abc import Iterable -from unittest import mock -from unittest.mock import MagicMock - -from dify_graph.graph import Graph -from dify_graph.graph_events import ( - GraphRunPausedEvent, - GraphRunStartedEvent, - GraphRunSucceededEvent, - NodeRunPauseRequestedEvent, - NodeRunStartedEvent, - NodeRunStreamChunkEvent, - NodeRunSucceededEvent, -) -from dify_graph.graph_events.node import NodeRunHumanInputFormFilledEvent -from dify_graph.model_runtime.entities.message_entities import PromptMessageRole -from dify_graph.nodes.base.entities import OutputVariableEntity, OutputVariableType -from dify_graph.nodes.end.end_node import EndNode -from dify_graph.nodes.end.entities import EndNodeData -from dify_graph.nodes.human_input.entities import HumanInputNodeData, UserAction -from dify_graph.nodes.human_input.human_input_node import HumanInputNode -from dify_graph.nodes.llm.entities import ( - ContextConfig, - LLMNodeChatModelMessage, - LLMNodeData, - ModelConfig, - VisionConfig, -) -from dify_graph.nodes.start.entities import StartNodeData -from dify_graph.nodes.start.start_node import StartNode -from dify_graph.repositories.human_input_form_repository import HumanInputFormEntity, HumanInputFormRepository -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable -from libs.datetime_utils import naive_utc_now -from tests.workflow_test_utils import build_test_graph_init_params - -from .test_mock_config import MockConfig -from .test_mock_nodes import MockLLMNode -from .test_table_runner import TableTestRunner, WorkflowTestCase - - -def _build_branching_graph( - mock_config: MockConfig, - form_repository: 
HumanInputFormRepository, - graph_runtime_state: GraphRuntimeState | None = None, -) -> tuple[Graph, GraphRuntimeState]: - graph_config: dict[str, object] = {"nodes": [], "edges": []} - graph_init_params = build_test_graph_init_params( - workflow_id="workflow", - graph_config=graph_config, - tenant_id="tenant", - app_id="app", - user_id="user", - user_from="account", - invoke_from="debugger", - call_depth=0, - ) - - if graph_runtime_state is None: - variable_pool = VariablePool( - system_variables=SystemVariable( - user_id="user", - app_id="app", - workflow_id="workflow", - workflow_execution_id="test-execution-id", - ), - user_inputs={}, - conversation_variables=[], - ) - graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - - start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()} - start_node = StartNode( - id=start_config["id"], - config=start_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - - def _create_llm_node(node_id: str, title: str, prompt_text: str) -> MockLLMNode: - llm_data = LLMNodeData( - title=title, - model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}), - prompt_template=[ - LLMNodeChatModelMessage( - text=prompt_text, - role=PromptMessageRole.USER, - edition_type="basic", - ) - ], - context=ContextConfig(enabled=False, variable_selector=None), - vision=VisionConfig(enabled=False), - reasoning_format="tagged", - ) - llm_config = {"id": node_id, "data": llm_data.model_dump()} - llm_node = MockLLMNode( - id=node_id, - config=llm_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - mock_config=mock_config, - credentials_provider=mock.Mock(), - model_factory=mock.Mock(), - ) - return llm_node - - llm_initial = _create_llm_node("llm_initial", "Initial LLM", "Initial stream") - - human_data = HumanInputNodeData( - title="Human Input", - 
form_content="Human input required", - inputs=[], - user_actions=[ - UserAction(id="primary", title="Primary"), - UserAction(id="secondary", title="Secondary"), - ], - ) - - human_config = {"id": "human", "data": human_data.model_dump()} - human_node = HumanInputNode( - id=human_config["id"], - config=human_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - form_repository=form_repository, - ) - - llm_primary = _create_llm_node("llm_primary", "Primary LLM", "Primary stream output") - llm_secondary = _create_llm_node("llm_secondary", "Secondary LLM", "Secondary") - - end_primary_data = EndNodeData( - title="End Primary", - outputs=[ - OutputVariableEntity( - variable="initial_text", value_type=OutputVariableType.STRING, value_selector=["llm_initial", "text"] - ), - OutputVariableEntity( - variable="primary_text", value_type=OutputVariableType.STRING, value_selector=["llm_primary", "text"] - ), - ], - desc=None, - ) - end_primary_config = {"id": "end_primary", "data": end_primary_data.model_dump()} - end_primary = EndNode( - id=end_primary_config["id"], - config=end_primary_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - - end_secondary_data = EndNodeData( - title="End Secondary", - outputs=[ - OutputVariableEntity( - variable="initial_text", value_type=OutputVariableType.STRING, value_selector=["llm_initial", "text"] - ), - OutputVariableEntity( - variable="secondary_text", - value_type=OutputVariableType.STRING, - value_selector=["llm_secondary", "text"], - ), - ], - desc=None, - ) - end_secondary_config = {"id": "end_secondary", "data": end_secondary_data.model_dump()} - end_secondary = EndNode( - id=end_secondary_config["id"], - config=end_secondary_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - - graph = ( - Graph.new() - .add_root(start_node) - .add_node(llm_initial) - .add_node(human_node) - .add_node(llm_primary, 
from_node_id="human", source_handle="primary") - .add_node(end_primary, from_node_id="llm_primary") - .add_node(llm_secondary, from_node_id="human", source_handle="secondary") - .add_node(end_secondary, from_node_id="llm_secondary") - .build() - ) - return graph, graph_runtime_state - - -def _expected_mock_llm_chunks(text: str) -> list[str]: - chunks: list[str] = [] - for index, word in enumerate(text.split(" ")): - chunk = word if index == 0 else f" {word}" - chunks.append(chunk) - chunks.append("") - return chunks - - -def _assert_stream_chunk_sequence( - chunk_events: Iterable[NodeRunStreamChunkEvent], - expected_nodes: list[str], - expected_chunks: list[str], -) -> None: - actual_nodes = [event.node_id for event in chunk_events] - actual_chunks = [event.chunk for event in chunk_events] - assert actual_nodes == expected_nodes - assert actual_chunks == expected_chunks - - -def test_human_input_llm_streaming_across_multiple_branches() -> None: - mock_config = MockConfig() - mock_config.set_node_outputs("llm_initial", {"text": "Initial stream"}) - mock_config.set_node_outputs("llm_primary", {"text": "Primary stream output"}) - mock_config.set_node_outputs("llm_secondary", {"text": "Secondary"}) - - branch_scenarios = [ - { - "handle": "primary", - "resume_llm": "llm_primary", - "end_node": "end_primary", - "expected_pre_chunks": [ - ("llm_initial", _expected_mock_llm_chunks("Initial stream")), # cached output before branch completes - ("end_primary", ["\n"]), # literal segment emitted when end_primary session activates - ], - "expected_post_chunks": [ - ("llm_primary", _expected_mock_llm_chunks("Primary stream output")), # live stream from chosen branch - ], - }, - { - "handle": "secondary", - "resume_llm": "llm_secondary", - "end_node": "end_secondary", - "expected_pre_chunks": [ - ("llm_initial", _expected_mock_llm_chunks("Initial stream")), # cached output before branch completes - ("end_secondary", ["\n"]), # literal segment emitted when end_secondary session 
activates - ], - "expected_post_chunks": [ - ("llm_secondary", _expected_mock_llm_chunks("Secondary")), # live stream from chosen branch - ], - }, - ] - - for scenario in branch_scenarios: - runner = TableTestRunner() - - mock_create_repo = MagicMock(spec=HumanInputFormRepository) - mock_create_repo.get_form.return_value = None - mock_form_entity = MagicMock(spec=HumanInputFormEntity) - mock_form_entity.id = "test_form_id" - mock_form_entity.web_app_token = "test_web_app_token" - mock_form_entity.recipients = [] - mock_form_entity.rendered_content = "rendered" - mock_form_entity.submitted = False - mock_create_repo.create_form.return_value = mock_form_entity - - def initial_graph_factory(mock_create_repo=mock_create_repo) -> tuple[Graph, GraphRuntimeState]: - return _build_branching_graph(mock_config, mock_create_repo) - - initial_case = WorkflowTestCase( - description="HumanInput pause before branching decision", - graph_factory=initial_graph_factory, - expected_event_sequence=[ - GraphRunStartedEvent, # initial run: graph execution starts - NodeRunStartedEvent, # start node begins execution - NodeRunSucceededEvent, # start node completes - NodeRunStartedEvent, # llm_initial starts streaming - NodeRunSucceededEvent, # llm_initial completes streaming - NodeRunStartedEvent, # human node begins and issues pause - NodeRunPauseRequestedEvent, # human node requests pause awaiting input - GraphRunPausedEvent, # graph run pauses awaiting resume - ], - ) - - initial_result = runner.run_test_case(initial_case) - - assert initial_result.success, initial_result.event_mismatch_details - assert not any(isinstance(event, NodeRunStreamChunkEvent) for event in initial_result.events) - - pre_chunk_count = sum(len(chunks) for _, chunks in scenario["expected_pre_chunks"]) - post_chunk_count = sum(len(chunks) for _, chunks in scenario["expected_post_chunks"]) - expected_pre_chunk_events_in_resumption = [ - GraphRunStartedEvent, - NodeRunStartedEvent, - 
NodeRunHumanInputFormFilledEvent, - ] - - expected_resume_sequence: list[type] = ( - expected_pre_chunk_events_in_resumption - + [NodeRunStreamChunkEvent] * pre_chunk_count - + [ - NodeRunSucceededEvent, - NodeRunStartedEvent, - ] - + [NodeRunStreamChunkEvent] * post_chunk_count - + [ - NodeRunSucceededEvent, - NodeRunStartedEvent, - NodeRunSucceededEvent, - GraphRunSucceededEvent, - ] - ) - - mock_get_repo = MagicMock(spec=HumanInputFormRepository) - submitted_form = MagicMock(spec=HumanInputFormEntity) - submitted_form.id = mock_form_entity.id - submitted_form.web_app_token = mock_form_entity.web_app_token - submitted_form.recipients = [] - submitted_form.rendered_content = mock_form_entity.rendered_content - submitted_form.submitted = True - submitted_form.selected_action_id = scenario["handle"] - submitted_form.submitted_data = {} - submitted_form.expiration_time = naive_utc_now() + datetime.timedelta(days=1) - mock_get_repo.get_form.return_value = submitted_form - - def resume_graph_factory( - initial_result=initial_result, mock_get_repo=mock_get_repo - ) -> tuple[Graph, GraphRuntimeState]: - assert initial_result.graph_runtime_state is not None - serialized_runtime_state = initial_result.graph_runtime_state.dumps() - resume_runtime_state = GraphRuntimeState.from_snapshot(serialized_runtime_state) - return _build_branching_graph(mock_config, mock_get_repo, resume_runtime_state) - - resume_case = WorkflowTestCase( - description=f"HumanInput resumes via {scenario['handle']} branch", - graph_factory=resume_graph_factory, - expected_event_sequence=expected_resume_sequence, - ) - - resume_result = runner.run_test_case(resume_case) - - assert resume_result.success, resume_result.event_mismatch_details - - resume_events = resume_result.events - - chunk_events = [event for event in resume_events if isinstance(event, NodeRunStreamChunkEvent)] - assert len(chunk_events) == pre_chunk_count + post_chunk_count - - pre_chunk_events = chunk_events[:pre_chunk_count] - 
post_chunk_events = chunk_events[pre_chunk_count:] - - expected_pre_nodes: list[str] = [] - expected_pre_chunks: list[str] = [] - for node_id, chunks in scenario["expected_pre_chunks"]: - expected_pre_nodes.extend([node_id] * len(chunks)) - expected_pre_chunks.extend(chunks) - _assert_stream_chunk_sequence(pre_chunk_events, expected_pre_nodes, expected_pre_chunks) - - expected_post_nodes: list[str] = [] - expected_post_chunks: list[str] = [] - for node_id, chunks in scenario["expected_post_chunks"]: - expected_post_nodes.extend([node_id] * len(chunks)) - expected_post_chunks.extend(chunks) - _assert_stream_chunk_sequence(post_chunk_events, expected_post_nodes, expected_post_chunks) - - human_success_index = next( - index - for index, event in enumerate(resume_events) - if isinstance(event, NodeRunSucceededEvent) and event.node_id == "human" - ) - pre_indices = [ - index - for index, event in enumerate(resume_events) - if isinstance(event, NodeRunStreamChunkEvent) and index < human_success_index - ] - expected_pre_chunk_events_count_in_resumption = len(expected_pre_chunk_events_in_resumption) - assert pre_indices == list(range(expected_pre_chunk_events_count_in_resumption, human_success_index)) - - resume_chunk_indices = [ - index - for index, event in enumerate(resume_events) - if isinstance(event, NodeRunStreamChunkEvent) and event.node_id == scenario["resume_llm"] - ] - assert resume_chunk_indices, "Expected streaming output from the selected branch" - resume_start_index = next( - index - for index, event in enumerate(resume_events) - if isinstance(event, NodeRunStartedEvent) and event.node_id == scenario["resume_llm"] - ) - resume_success_index = next( - index - for index, event in enumerate(resume_events) - if isinstance(event, NodeRunSucceededEvent) and event.node_id == scenario["resume_llm"] - ) - assert resume_start_index < min(resume_chunk_indices) - assert max(resume_chunk_indices) < resume_success_index - - started_nodes = [event.node_id for event in 
resume_events if isinstance(event, NodeRunStartedEvent)] - assert started_nodes == ["human", scenario["resume_llm"], scenario["end_node"]] diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_single_branch.py b/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_single_branch.py deleted file mode 100644 index 36bba6deb6..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_human_input_pause_single_branch.py +++ /dev/null @@ -1,344 +0,0 @@ -import datetime -import time -from unittest import mock -from unittest.mock import MagicMock - -from dify_graph.graph import Graph -from dify_graph.graph_events import ( - GraphRunPausedEvent, - GraphRunStartedEvent, - GraphRunSucceededEvent, - NodeRunPauseRequestedEvent, - NodeRunStartedEvent, - NodeRunStreamChunkEvent, - NodeRunSucceededEvent, -) -from dify_graph.graph_events.node import NodeRunHumanInputFormFilledEvent -from dify_graph.model_runtime.entities.message_entities import PromptMessageRole -from dify_graph.nodes.base.entities import OutputVariableEntity, OutputVariableType -from dify_graph.nodes.end.end_node import EndNode -from dify_graph.nodes.end.entities import EndNodeData -from dify_graph.nodes.human_input.entities import HumanInputNodeData, UserAction -from dify_graph.nodes.human_input.human_input_node import HumanInputNode -from dify_graph.nodes.llm.entities import ( - ContextConfig, - LLMNodeChatModelMessage, - LLMNodeData, - ModelConfig, - VisionConfig, -) -from dify_graph.nodes.start.entities import StartNodeData -from dify_graph.nodes.start.start_node import StartNode -from dify_graph.repositories.human_input_form_repository import HumanInputFormEntity, HumanInputFormRepository -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable -from libs.datetime_utils import naive_utc_now -from tests.workflow_test_utils import build_test_graph_init_params - -from .test_mock_config import 
MockConfig -from .test_mock_nodes import MockLLMNode -from .test_table_runner import TableTestRunner, WorkflowTestCase - - -def _build_llm_human_llm_graph( - mock_config: MockConfig, - form_repository: HumanInputFormRepository, - graph_runtime_state: GraphRuntimeState | None = None, -) -> tuple[Graph, GraphRuntimeState]: - graph_config: dict[str, object] = {"nodes": [], "edges": []} - graph_init_params = build_test_graph_init_params( - workflow_id="workflow", - graph_config=graph_config, - tenant_id="tenant", - app_id="app", - user_id="user", - user_from="account", - invoke_from="debugger", - call_depth=0, - ) - - if graph_runtime_state is None: - variable_pool = VariablePool( - system_variables=SystemVariable( - user_id="user", app_id="app", workflow_id="workflow", workflow_execution_id="test-execution-id," - ), - user_inputs={}, - conversation_variables=[], - ) - graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - - start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()} - start_node = StartNode( - id=start_config["id"], - config=start_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - - def _create_llm_node(node_id: str, title: str, prompt_text: str) -> MockLLMNode: - llm_data = LLMNodeData( - title=title, - model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}), - prompt_template=[ - LLMNodeChatModelMessage( - text=prompt_text, - role=PromptMessageRole.USER, - edition_type="basic", - ) - ], - context=ContextConfig(enabled=False, variable_selector=None), - vision=VisionConfig(enabled=False), - reasoning_format="tagged", - ) - llm_config = {"id": node_id, "data": llm_data.model_dump()} - llm_node = MockLLMNode( - id=node_id, - config=llm_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - mock_config=mock_config, - credentials_provider=mock.Mock(), - 
model_factory=mock.Mock(), - ) - return llm_node - - llm_first = _create_llm_node("llm_initial", "Initial LLM", "Initial prompt") - - human_data = HumanInputNodeData( - title="Human Input", - form_content="Human input required", - inputs=[], - user_actions=[ - UserAction(id="accept", title="Accept"), - UserAction(id="reject", title="Reject"), - ], - ) - - human_config = {"id": "human", "data": human_data.model_dump()} - human_node = HumanInputNode( - id=human_config["id"], - config=human_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - form_repository=form_repository, - ) - - llm_second = _create_llm_node("llm_resume", "Follow-up LLM", "Follow-up prompt") - - end_data = EndNodeData( - title="End", - outputs=[ - OutputVariableEntity( - variable="initial_text", value_type=OutputVariableType.STRING, value_selector=["llm_initial", "text"] - ), - OutputVariableEntity( - variable="resume_text", value_type=OutputVariableType.STRING, value_selector=["llm_resume", "text"] - ), - ], - desc=None, - ) - end_config = {"id": "end", "data": end_data.model_dump()} - end_node = EndNode( - id=end_config["id"], - config=end_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - - graph = ( - Graph.new() - .add_root(start_node) - .add_node(llm_first) - .add_node(human_node) - .add_node(llm_second, source_handle="accept") - .add_node(end_node) - .build() - ) - return graph, graph_runtime_state - - -def _expected_mock_llm_chunks(text: str) -> list[str]: - chunks: list[str] = [] - for index, word in enumerate(text.split(" ")): - chunk = word if index == 0 else f" {word}" - chunks.append(chunk) - chunks.append("") - return chunks - - -def test_human_input_llm_streaming_order_across_pause() -> None: - runner = TableTestRunner() - - initial_text = "Hello, pause" - resume_text = "Welcome back!" 
- - mock_config = MockConfig() - mock_config.set_node_outputs("llm_initial", {"text": initial_text}) - mock_config.set_node_outputs("llm_resume", {"text": resume_text}) - - expected_initial_sequence: list[type] = [ - GraphRunStartedEvent, # graph run begins - NodeRunStartedEvent, # start node begins - NodeRunSucceededEvent, # start node completes - NodeRunStartedEvent, # llm_initial begins streaming - NodeRunSucceededEvent, # llm_initial completes streaming - NodeRunStartedEvent, # human node begins and requests pause - NodeRunPauseRequestedEvent, # human node pause requested - GraphRunPausedEvent, # graph run pauses awaiting resume - ] - - mock_create_repo = MagicMock(spec=HumanInputFormRepository) - mock_create_repo.get_form.return_value = None - mock_form_entity = MagicMock(spec=HumanInputFormEntity) - mock_form_entity.id = "test_form_id" - mock_form_entity.web_app_token = "test_web_app_token" - mock_form_entity.recipients = [] - mock_form_entity.rendered_content = "rendered" - mock_form_entity.submitted = False - mock_create_repo.create_form.return_value = mock_form_entity - - def graph_factory() -> tuple[Graph, GraphRuntimeState]: - return _build_llm_human_llm_graph(mock_config, mock_create_repo) - - initial_case = WorkflowTestCase( - description="HumanInput pause preserves LLM streaming order", - graph_factory=graph_factory, - expected_event_sequence=expected_initial_sequence, - ) - - initial_result = runner.run_test_case(initial_case) - - assert initial_result.success, initial_result.event_mismatch_details - - initial_events = initial_result.events - initial_chunks = _expected_mock_llm_chunks(initial_text) - - initial_stream_chunk_events = [event for event in initial_events if isinstance(event, NodeRunStreamChunkEvent)] - assert initial_stream_chunk_events == [] - - pause_index = next(i for i, event in enumerate(initial_events) if isinstance(event, GraphRunPausedEvent)) - llm_succeeded_index = next( - i - for i, event in enumerate(initial_events) - if 
isinstance(event, NodeRunSucceededEvent) and event.node_id == "llm_initial" - ) - assert llm_succeeded_index < pause_index - - graph_runtime_state = initial_result.graph_runtime_state - graph = initial_result.graph - assert graph_runtime_state is not None - assert graph is not None - - coordinator = graph_runtime_state.response_coordinator - stream_buffers = coordinator._stream_buffers # Tests may access internals for assertions - assert ("llm_initial", "text") in stream_buffers - initial_stream_chunks = [event.chunk for event in stream_buffers[("llm_initial", "text")]] - assert initial_stream_chunks == initial_chunks - assert ("llm_resume", "text") not in stream_buffers - - resume_chunks = _expected_mock_llm_chunks(resume_text) - expected_resume_sequence: list[type] = [ - GraphRunStartedEvent, # resumed graph run begins - NodeRunStartedEvent, # human node restarts - # Form Filled should be generated first, then the node execution ends and stream chunk is generated. - NodeRunHumanInputFormFilledEvent, - NodeRunStreamChunkEvent, # cached llm_initial chunk 1 - NodeRunStreamChunkEvent, # cached llm_initial chunk 2 - NodeRunStreamChunkEvent, # cached llm_initial final chunk - NodeRunStreamChunkEvent, # end node emits combined template separator - NodeRunSucceededEvent, # human node finishes instantly after input - NodeRunStartedEvent, # llm_resume begins streaming - NodeRunStreamChunkEvent, # llm_resume chunk 1 - NodeRunStreamChunkEvent, # llm_resume chunk 2 - NodeRunStreamChunkEvent, # llm_resume final chunk - NodeRunSucceededEvent, # llm_resume completes streaming - NodeRunStartedEvent, # end node starts - NodeRunSucceededEvent, # end node finishes - GraphRunSucceededEvent, # graph run succeeds after resume - ] - - mock_get_repo = MagicMock(spec=HumanInputFormRepository) - submitted_form = MagicMock(spec=HumanInputFormEntity) - submitted_form.id = mock_form_entity.id - submitted_form.web_app_token = mock_form_entity.web_app_token - submitted_form.recipients = [] - 
submitted_form.rendered_content = mock_form_entity.rendered_content - submitted_form.submitted = True - submitted_form.selected_action_id = "accept" - submitted_form.submitted_data = {} - submitted_form.expiration_time = naive_utc_now() + datetime.timedelta(days=1) - mock_get_repo.get_form.return_value = submitted_form - - def resume_graph_factory() -> tuple[Graph, GraphRuntimeState]: - # restruct the graph runtime state - serialized_runtime_state = initial_result.graph_runtime_state.dumps() - resume_runtime_state = GraphRuntimeState.from_snapshot(serialized_runtime_state) - return _build_llm_human_llm_graph( - mock_config, - mock_get_repo, - resume_runtime_state, - ) - - resume_case = WorkflowTestCase( - description="HumanInput resume continues LLM streaming order", - graph_factory=resume_graph_factory, - expected_event_sequence=expected_resume_sequence, - ) - - resume_result = runner.run_test_case(resume_case) - - assert resume_result.success, resume_result.event_mismatch_details - - resume_events = resume_result.events - - success_index = next(i for i, event in enumerate(resume_events) if isinstance(event, GraphRunSucceededEvent)) - llm_resume_succeeded_index = next( - i - for i, event in enumerate(resume_events) - if isinstance(event, NodeRunSucceededEvent) and event.node_id == "llm_resume" - ) - assert llm_resume_succeeded_index < success_index - - resume_chunk_events = [event for event in resume_events if isinstance(event, NodeRunStreamChunkEvent)] - assert [event.node_id for event in resume_chunk_events[:3]] == ["llm_initial"] * 3 - assert [event.chunk for event in resume_chunk_events[:3]] == initial_chunks - assert resume_chunk_events[3].node_id == "end" - assert resume_chunk_events[3].chunk == "\n" - assert [event.node_id for event in resume_chunk_events[4:]] == ["llm_resume"] * 3 - assert [event.chunk for event in resume_chunk_events[4:]] == resume_chunks - - human_success_index = next( - i - for i, event in enumerate(resume_events) - if isinstance(event, 
NodeRunSucceededEvent) and event.node_id == "human" - ) - cached_chunk_indices = [ - i - for i, event in enumerate(resume_events) - if isinstance(event, NodeRunStreamChunkEvent) and event.node_id in {"llm_initial", "end"} - ] - assert all(index < human_success_index for index in cached_chunk_indices) - - llm_resume_start_index = next( - i - for i, event in enumerate(resume_events) - if isinstance(event, NodeRunStartedEvent) and event.node_id == "llm_resume" - ) - llm_resume_success_index = next( - i - for i, event in enumerate(resume_events) - if isinstance(event, NodeRunSucceededEvent) and event.node_id == "llm_resume" - ) - llm_resume_chunk_indices = [ - i - for i, event in enumerate(resume_events) - if isinstance(event, NodeRunStreamChunkEvent) and event.node_id == "llm_resume" - ] - assert llm_resume_chunk_indices - first_resume_chunk_index = min(llm_resume_chunk_indices) - last_resume_chunk_index = max(llm_resume_chunk_indices) - assert llm_resume_start_index < first_resume_chunk_index - assert last_resume_chunk_index < llm_resume_success_index - - started_nodes = [event.node_id for event in resume_events if isinstance(event, NodeRunStartedEvent)] - assert started_nodes == ["human", "llm_resume", "end"] diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_if_else_streaming.py b/api/tests/unit_tests/core/workflow/graph_engine/test_if_else_streaming.py deleted file mode 100644 index 8da179c15e..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_if_else_streaming.py +++ /dev/null @@ -1,324 +0,0 @@ -import time -from unittest import mock - -from dify_graph.graph import Graph -from dify_graph.graph_events import ( - GraphRunStartedEvent, - GraphRunSucceededEvent, - NodeRunStartedEvent, - NodeRunStreamChunkEvent, - NodeRunSucceededEvent, -) -from dify_graph.model_runtime.entities.llm_entities import LLMMode -from dify_graph.model_runtime.entities.message_entities import PromptMessageRole -from dify_graph.nodes.base.entities import 
OutputVariableEntity, OutputVariableType -from dify_graph.nodes.end.end_node import EndNode -from dify_graph.nodes.end.entities import EndNodeData -from dify_graph.nodes.if_else.entities import IfElseNodeData -from dify_graph.nodes.if_else.if_else_node import IfElseNode -from dify_graph.nodes.llm.entities import ( - ContextConfig, - LLMNodeChatModelMessage, - LLMNodeData, - ModelConfig, - VisionConfig, -) -from dify_graph.nodes.start.entities import StartNodeData -from dify_graph.nodes.start.start_node import StartNode -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable -from dify_graph.utils.condition.entities import Condition -from tests.workflow_test_utils import build_test_graph_init_params - -from .test_mock_config import MockConfig -from .test_mock_nodes import MockLLMNode -from .test_table_runner import TableTestRunner, WorkflowTestCase - - -def _build_if_else_graph(branch_value: str, mock_config: MockConfig) -> tuple[Graph, GraphRuntimeState]: - graph_config: dict[str, object] = {"nodes": [], "edges": []} - graph_init_params = build_test_graph_init_params( - graph_config=graph_config, - user_from="account", - invoke_from="debugger", - ) - - variable_pool = VariablePool( - system_variables=SystemVariable(user_id="user", app_id="app", workflow_id="workflow"), - user_inputs={}, - conversation_variables=[], - ) - variable_pool.add(("branch", "value"), branch_value) - graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - - start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()} - start_node = StartNode( - id=start_config["id"], - config=start_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - - def _create_llm_node(node_id: str, title: str, prompt_text: str) -> MockLLMNode: - llm_data = LLMNodeData( - title=title, - model=ModelConfig(provider="openai", 
name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}), - prompt_template=[ - LLMNodeChatModelMessage( - text=prompt_text, - role=PromptMessageRole.USER, - edition_type="basic", - ) - ], - context=ContextConfig(enabled=False, variable_selector=None), - vision=VisionConfig(enabled=False), - reasoning_format="tagged", - ) - llm_config = {"id": node_id, "data": llm_data.model_dump()} - llm_node = MockLLMNode( - id=node_id, - config=llm_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - mock_config=mock_config, - credentials_provider=mock.Mock(), - model_factory=mock.Mock(), - ) - return llm_node - - llm_initial = _create_llm_node("llm_initial", "Initial LLM", "Initial stream") - - if_else_data = IfElseNodeData( - title="IfElse", - cases=[ - IfElseNodeData.Case( - case_id="primary", - logical_operator="and", - conditions=[ - Condition(variable_selector=["branch", "value"], comparison_operator="is", value="primary") - ], - ), - IfElseNodeData.Case( - case_id="secondary", - logical_operator="and", - conditions=[ - Condition(variable_selector=["branch", "value"], comparison_operator="is", value="secondary") - ], - ), - ], - ) - if_else_config = {"id": "if_else", "data": if_else_data.model_dump()} - if_else_node = IfElseNode( - id=if_else_config["id"], - config=if_else_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - - llm_primary = _create_llm_node("llm_primary", "Primary LLM", "Primary stream output") - llm_secondary = _create_llm_node("llm_secondary", "Secondary LLM", "Secondary") - - end_primary_data = EndNodeData( - title="End Primary", - outputs=[ - OutputVariableEntity( - variable="initial_text", value_type=OutputVariableType.STRING, value_selector=["llm_initial", "text"] - ), - OutputVariableEntity( - variable="primary_text", value_type=OutputVariableType.STRING, value_selector=["llm_primary", "text"] - ), - ], - desc=None, - ) - end_primary_config = {"id": 
"end_primary", "data": end_primary_data.model_dump()} - end_primary = EndNode( - id=end_primary_config["id"], - config=end_primary_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - - end_secondary_data = EndNodeData( - title="End Secondary", - outputs=[ - OutputVariableEntity( - variable="initial_text", value_type=OutputVariableType.STRING, value_selector=["llm_initial", "text"] - ), - OutputVariableEntity( - variable="secondary_text", - value_type=OutputVariableType.STRING, - value_selector=["llm_secondary", "text"], - ), - ], - desc=None, - ) - end_secondary_config = {"id": "end_secondary", "data": end_secondary_data.model_dump()} - end_secondary = EndNode( - id=end_secondary_config["id"], - config=end_secondary_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - - graph = ( - Graph.new() - .add_root(start_node) - .add_node(llm_initial) - .add_node(if_else_node) - .add_node(llm_primary, from_node_id="if_else", source_handle="primary") - .add_node(end_primary, from_node_id="llm_primary") - .add_node(llm_secondary, from_node_id="if_else", source_handle="secondary") - .add_node(end_secondary, from_node_id="llm_secondary") - .build() - ) - return graph, graph_runtime_state - - -def _expected_mock_llm_chunks(text: str) -> list[str]: - chunks: list[str] = [] - for index, word in enumerate(text.split(" ")): - chunk = word if index == 0 else f" {word}" - chunks.append(chunk) - chunks.append("") - return chunks - - -def test_if_else_llm_streaming_order() -> None: - mock_config = MockConfig() - mock_config.set_node_outputs("llm_initial", {"text": "Initial stream"}) - mock_config.set_node_outputs("llm_primary", {"text": "Primary stream output"}) - mock_config.set_node_outputs("llm_secondary", {"text": "Secondary"}) - - scenarios = [ - { - "branch": "primary", - "resume_llm": "llm_primary", - "end_node": "end_primary", - "expected_sequence": [ - GraphRunStartedEvent, # graph run begins 
- NodeRunStartedEvent, # start node begins execution - NodeRunSucceededEvent, # start node completes - NodeRunStartedEvent, # llm_initial starts and streams - NodeRunSucceededEvent, # llm_initial completes streaming - NodeRunStartedEvent, # if_else evaluates conditions - NodeRunStreamChunkEvent, # cached llm_initial chunk 1 flushed - NodeRunStreamChunkEvent, # cached llm_initial chunk 2 flushed - NodeRunStreamChunkEvent, # cached llm_initial final chunk flushed - NodeRunStreamChunkEvent, # template literal newline emitted - NodeRunSucceededEvent, # if_else completes branch selection - NodeRunStartedEvent, # llm_primary begins streaming - NodeRunStreamChunkEvent, # llm_primary chunk 1 - NodeRunStreamChunkEvent, # llm_primary chunk 2 - NodeRunStreamChunkEvent, # llm_primary chunk 3 - NodeRunStreamChunkEvent, # llm_primary final chunk - NodeRunSucceededEvent, # llm_primary completes streaming - NodeRunStartedEvent, # end_primary node starts - NodeRunSucceededEvent, # end_primary finishes aggregation - GraphRunSucceededEvent, # graph run succeeds - ], - "expected_chunks": [ - ("llm_initial", _expected_mock_llm_chunks("Initial stream")), - ("end_primary", ["\n"]), - ("llm_primary", _expected_mock_llm_chunks("Primary stream output")), - ], - }, - { - "branch": "secondary", - "resume_llm": "llm_secondary", - "end_node": "end_secondary", - "expected_sequence": [ - GraphRunStartedEvent, # graph run begins - NodeRunStartedEvent, # start node begins execution - NodeRunSucceededEvent, # start node completes - NodeRunStartedEvent, # llm_initial starts and streams - NodeRunSucceededEvent, # llm_initial completes streaming - NodeRunStartedEvent, # if_else evaluates conditions - NodeRunStreamChunkEvent, # cached llm_initial chunk 1 flushed - NodeRunStreamChunkEvent, # cached llm_initial chunk 2 flushed - NodeRunStreamChunkEvent, # cached llm_initial final chunk flushed - NodeRunStreamChunkEvent, # template literal newline emitted - NodeRunSucceededEvent, # if_else completes branch 
selection - NodeRunStartedEvent, # llm_secondary begins streaming - NodeRunStreamChunkEvent, # llm_secondary chunk 1 - NodeRunStreamChunkEvent, # llm_secondary final chunk - NodeRunSucceededEvent, # llm_secondary completes - NodeRunStartedEvent, # end_secondary node starts - NodeRunSucceededEvent, # end_secondary finishes aggregation - GraphRunSucceededEvent, # graph run succeeds - ], - "expected_chunks": [ - ("llm_initial", _expected_mock_llm_chunks("Initial stream")), - ("end_secondary", ["\n"]), - ("llm_secondary", _expected_mock_llm_chunks("Secondary")), - ], - }, - ] - - for scenario in scenarios: - runner = TableTestRunner() - - def graph_factory( - branch_value: str = scenario["branch"], - cfg: MockConfig = mock_config, - ) -> tuple[Graph, GraphRuntimeState]: - return _build_if_else_graph(branch_value, cfg) - - test_case = WorkflowTestCase( - description=f"IfElse streaming via {scenario['branch']} branch", - graph_factory=graph_factory, - expected_event_sequence=scenario["expected_sequence"], - ) - - result = runner.run_test_case(test_case) - - assert result.success, result.event_mismatch_details - - chunk_events = [event for event in result.events if isinstance(event, NodeRunStreamChunkEvent)] - expected_nodes: list[str] = [] - expected_chunks: list[str] = [] - for node_id, chunks in scenario["expected_chunks"]: - expected_nodes.extend([node_id] * len(chunks)) - expected_chunks.extend(chunks) - assert [event.node_id for event in chunk_events] == expected_nodes - assert [event.chunk for event in chunk_events] == expected_chunks - - branch_node_index = next( - index - for index, event in enumerate(result.events) - if isinstance(event, NodeRunStartedEvent) and event.node_id == "if_else" - ) - branch_success_index = next( - index - for index, event in enumerate(result.events) - if isinstance(event, NodeRunSucceededEvent) and event.node_id == "if_else" - ) - pre_branch_chunk_indices = [ - index - for index, event in enumerate(result.events) - if 
isinstance(event, NodeRunStreamChunkEvent) and index < branch_success_index - ] - assert len(pre_branch_chunk_indices) == len(_expected_mock_llm_chunks("Initial stream")) + 1 - assert min(pre_branch_chunk_indices) == branch_node_index + 1 - assert max(pre_branch_chunk_indices) < branch_success_index - - resume_chunk_indices = [ - index - for index, event in enumerate(result.events) - if isinstance(event, NodeRunStreamChunkEvent) and event.node_id == scenario["resume_llm"] - ] - assert resume_chunk_indices - resume_start_index = next( - index - for index, event in enumerate(result.events) - if isinstance(event, NodeRunStartedEvent) and event.node_id == scenario["resume_llm"] - ) - resume_success_index = next( - index - for index, event in enumerate(result.events) - if isinstance(event, NodeRunSucceededEvent) and event.node_id == scenario["resume_llm"] - ) - assert resume_start_index < min(resume_chunk_indices) - assert max(resume_chunk_indices) < resume_success_index - - started_nodes = [event.node_id for event in result.events if isinstance(event, NodeRunStartedEvent)] - assert started_nodes == ["start", "llm_initial", "if_else", scenario["resume_llm"], scenario["end_node"]] diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_iteration_flatten_output.py b/api/tests/unit_tests/core/workflow/graph_engine/test_iteration_flatten_output.py deleted file mode 100644 index b9bf4be13a..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_iteration_flatten_output.py +++ /dev/null @@ -1,126 +0,0 @@ -""" -Test cases for the Iteration node's flatten_output functionality. - -This module tests the iteration node's ability to: -1. Flatten array outputs when flatten_output=True (default) -2. 
Preserve nested array structure when flatten_output=False -""" - -from .test_database_utils import skip_if_database_unavailable -from .test_mock_config import MockConfigBuilder, NodeMockConfig -from .test_table_runner import TableTestRunner, WorkflowTestCase - - -def _create_iteration_mock_config(): - """Helper to create a mock config for iteration tests.""" - - def code_inner_handler(node): - pool = node.graph_runtime_state.variable_pool - item_seg = pool.get(["iteration_node", "item"]) - if item_seg is not None: - item = item_seg.to_object() - return {"result": [item, item * 2]} - # This fallback is likely unreachable, but if it is, - # it doesn't simulate iteration with different values as the comment suggests. - return {"result": [1, 2]} - - return ( - MockConfigBuilder() - .with_node_output("code_node", {"result": [1, 2, 3]}) - .with_node_config(NodeMockConfig(node_id="code_inner_node", custom_handler=code_inner_handler)) - .build() - ) - - -@skip_if_database_unavailable() -def test_iteration_with_flatten_output_enabled(): - """ - Test iteration node with flatten_output=True (default behavior). - - The fixture implements an iteration that: - 1. Iterates over [1, 2, 3] - 2. For each item, outputs [item, item*2] - 3. 
With flatten_output=True, should output [1, 2, 2, 4, 3, 6] - """ - runner = TableTestRunner() - - test_case = WorkflowTestCase( - fixture_path="iteration_flatten_output_enabled_workflow", - inputs={}, - expected_outputs={"output": [1, 2, 2, 4, 3, 6]}, - description="Iteration with flatten_output=True flattens nested arrays", - use_auto_mock=True, # Use auto-mock to avoid sandbox service - mock_config=_create_iteration_mock_config(), - ) - - result = runner.run_test_case(test_case) - - assert result.success, f"Test failed: {result.error}" - assert result.actual_outputs is not None, "Should have outputs" - assert result.actual_outputs == {"output": [1, 2, 2, 4, 3, 6]}, ( - f"Expected flattened output [1, 2, 2, 4, 3, 6], got {result.actual_outputs}" - ) - - -@skip_if_database_unavailable() -def test_iteration_with_flatten_output_disabled(): - """ - Test iteration node with flatten_output=False. - - The fixture implements an iteration that: - 1. Iterates over [1, 2, 3] - 2. For each item, outputs [item, item*2] - 3. With flatten_output=False, should output [[1, 2], [2, 4], [3, 6]] - """ - runner = TableTestRunner() - - test_case = WorkflowTestCase( - fixture_path="iteration_flatten_output_disabled_workflow", - inputs={}, - expected_outputs={"output": [[1, 2], [2, 4], [3, 6]]}, - description="Iteration with flatten_output=False preserves nested structure", - use_auto_mock=True, # Use auto-mock to avoid sandbox service - mock_config=_create_iteration_mock_config(), - ) - - result = runner.run_test_case(test_case) - - assert result.success, f"Test failed: {result.error}" - assert result.actual_outputs is not None, "Should have outputs" - assert result.actual_outputs == {"output": [[1, 2], [2, 4], [3, 6]]}, ( - f"Expected nested output [[1, 2], [2, 4], [3, 6]], got {result.actual_outputs}" - ) - - -@skip_if_database_unavailable() -def test_iteration_flatten_output_comparison(): - """ - Run both flatten_output configurations in parallel to verify the difference. 
- """ - runner = TableTestRunner() - - test_cases = [ - WorkflowTestCase( - fixture_path="iteration_flatten_output_enabled_workflow", - inputs={}, - expected_outputs={"output": [1, 2, 2, 4, 3, 6]}, - description="flatten_output=True: Flattened output", - use_auto_mock=True, # Use auto-mock to avoid sandbox service - mock_config=_create_iteration_mock_config(), - ), - WorkflowTestCase( - fixture_path="iteration_flatten_output_disabled_workflow", - inputs={}, - expected_outputs={"output": [[1, 2], [2, 4], [3, 6]]}, - description="flatten_output=False: Nested output", - use_auto_mock=True, # Use auto-mock to avoid sandbox service - mock_config=_create_iteration_mock_config(), - ), - ] - - suite_result = runner.run_table_tests(test_cases, parallel=True) - - # Assert all tests passed - assert suite_result.passed_tests == 2, f"Expected 2 passed tests, got {suite_result.passed_tests}" - assert suite_result.failed_tests == 0, f"Expected 0 failed tests, got {suite_result.failed_tests}" - assert suite_result.success_rate == 100.0, f"Expected 100% success rate, got {suite_result.success_rate}" diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_loop_contains_answer.py b/api/tests/unit_tests/core/workflow/graph_engine/test_loop_contains_answer.py deleted file mode 100644 index 733fd53bc8..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_loop_contains_answer.py +++ /dev/null @@ -1,85 +0,0 @@ -""" -Test case for loop with inner answer output error scenario. - -This test validates the behavior of a loop containing an answer node -inside the loop that may produce output errors. 
-""" - -from dify_graph.graph_events import ( - GraphRunStartedEvent, - GraphRunSucceededEvent, - NodeRunLoopNextEvent, - NodeRunLoopStartedEvent, - NodeRunLoopSucceededEvent, - NodeRunStartedEvent, - NodeRunStreamChunkEvent, - NodeRunSucceededEvent, -) - -from .test_mock_config import MockConfigBuilder -from .test_table_runner import TableTestRunner, WorkflowTestCase - - -def test_loop_contains_answer(): - """ - Test loop with inner answer node that may have output errors. - - The fixture implements a loop that: - 1. Iterates 4 times (index 0-3) - 2. Contains an inner answer node that outputs index and item values - 3. Has a break condition when index equals 4 - 4. Tests error handling for answer nodes within loops - """ - fixture_name = "loop_contains_answer" - mock_config = MockConfigBuilder().build() - - case = WorkflowTestCase( - fixture_path=fixture_name, - use_auto_mock=True, - mock_config=mock_config, - query="1", - expected_outputs={"answer": "1\n2\n1 + 2"}, - expected_event_sequence=[ - # Graph start - GraphRunStartedEvent, - # Start - NodeRunStartedEvent, - NodeRunSucceededEvent, - # Loop start - NodeRunStartedEvent, - NodeRunLoopStartedEvent, - # Variable assigner - NodeRunStartedEvent, - NodeRunStreamChunkEvent, # 1 - NodeRunStreamChunkEvent, # \n - NodeRunSucceededEvent, - # Answer - NodeRunStartedEvent, - NodeRunSucceededEvent, - # Loop next - NodeRunLoopNextEvent, - # Variable assigner - NodeRunStartedEvent, - NodeRunStreamChunkEvent, # 2 - NodeRunStreamChunkEvent, # \n - NodeRunSucceededEvent, - # Answer - NodeRunStartedEvent, - NodeRunSucceededEvent, - # Loop end - NodeRunLoopSucceededEvent, - NodeRunStreamChunkEvent, # 1 - NodeRunStreamChunkEvent, # + - NodeRunStreamChunkEvent, # 2 - NodeRunSucceededEvent, - # Answer - NodeRunStartedEvent, - NodeRunSucceededEvent, - # Graph end - GraphRunSucceededEvent, - ], - ) - - runner = TableTestRunner() - result = runner.run_test_case(case) - assert result.success, f"Test failed: {result.error}" diff --git 
a/api/tests/unit_tests/core/workflow/graph_engine/test_loop_node.py b/api/tests/unit_tests/core/workflow/graph_engine/test_loop_node.py deleted file mode 100644 index ad8d777ea6..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_loop_node.py +++ /dev/null @@ -1,41 +0,0 @@ -""" -Test cases for the Loop node functionality using TableTestRunner. - -This module tests the loop node's ability to: -1. Execute iterations with loop variables -2. Handle break conditions correctly -3. Update and propagate loop variables between iterations -4. Output the final loop variable value -""" - -from tests.unit_tests.core.workflow.graph_engine.test_table_runner import ( - TableTestRunner, - WorkflowTestCase, -) - - -def test_loop_with_break_condition(): - """ - Test loop node with break condition. - - The increment_loop_with_break_condition_workflow.yml fixture implements a loop that: - 1. Starts with num=1 - 2. Increments num by 1 each iteration - 3. Breaks when num >= 5 - 4. Should output {"num": 5} - """ - runner = TableTestRunner() - - test_case = WorkflowTestCase( - fixture_path="increment_loop_with_break_condition_workflow", - inputs={}, # No inputs needed for this test - expected_outputs={"num": 5}, - description="Loop with break condition when num >= 5", - ) - - result = runner.run_test_case(test_case) - - # Assert the test passed - assert result.success, f"Test failed: {result.error}" - assert result.actual_outputs is not None, "Should have outputs" - assert result.actual_outputs == {"num": 5}, f"Expected {{'num': 5}}, got {result.actual_outputs}" diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_loop_with_tool.py b/api/tests/unit_tests/core/workflow/graph_engine/test_loop_with_tool.py deleted file mode 100644 index 6ff2722f78..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_loop_with_tool.py +++ /dev/null @@ -1,67 +0,0 @@ -from dify_graph.graph_events import ( - GraphRunStartedEvent, - GraphRunSucceededEvent, - 
NodeRunLoopNextEvent, - NodeRunLoopStartedEvent, - NodeRunLoopSucceededEvent, - NodeRunStartedEvent, - NodeRunStreamChunkEvent, - NodeRunSucceededEvent, -) - -from .test_mock_config import MockConfigBuilder -from .test_table_runner import TableTestRunner, WorkflowTestCase - - -def test_loop_with_tool(): - fixture_name = "search_dify_from_2023_to_2025" - mock_config = ( - MockConfigBuilder() - .with_tool_response( - { - "text": "mocked search result", - } - ) - .build() - ) - case = WorkflowTestCase( - fixture_path=fixture_name, - use_auto_mock=True, - mock_config=mock_config, - expected_outputs={ - "answer": """- mocked search result -- mocked search result""" - }, - expected_event_sequence=[ - GraphRunStartedEvent, - # START - NodeRunStartedEvent, - NodeRunSucceededEvent, - # LOOP START - NodeRunStartedEvent, - NodeRunLoopStartedEvent, - # 2023 - NodeRunStartedEvent, - NodeRunSucceededEvent, - NodeRunStartedEvent, - NodeRunSucceededEvent, - NodeRunLoopNextEvent, - # 2024 - NodeRunStartedEvent, - NodeRunSucceededEvent, - NodeRunStartedEvent, - NodeRunSucceededEvent, - # LOOP END - NodeRunLoopSucceededEvent, - NodeRunStreamChunkEvent, # loop.res - NodeRunSucceededEvent, - # ANSWER - NodeRunStartedEvent, - NodeRunSucceededEvent, - GraphRunSucceededEvent, - ], - ) - - runner = TableTestRunner() - result = runner.run_test_case(case) - assert result.success, f"Test failed: {result.error}" diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_example.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_example.py deleted file mode 100644 index c511548749..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_example.py +++ /dev/null @@ -1,281 +0,0 @@ -""" -Example demonstrating the auto-mock system for testing workflows. - -This example shows how to test workflows with third-party service nodes -without making actual API calls. 
-""" - -from .test_mock_config import MockConfigBuilder -from .test_table_runner import TableTestRunner, WorkflowTestCase - - -def example_test_llm_workflow(): - """ - Example: Testing a workflow with an LLM node. - - This demonstrates how to test a workflow that uses an LLM service - without making actual API calls to OpenAI, Anthropic, etc. - """ - print("\n=== Example: Testing LLM Workflow ===\n") - - # Initialize the test runner - runner = TableTestRunner() - - # Configure mock responses - mock_config = MockConfigBuilder().with_llm_response("I'm a helpful AI assistant. How can I help you today?").build() - - # Define the test case - test_case = WorkflowTestCase( - fixture_path="llm-simple", - inputs={"query": "Hello, AI!"}, - expected_outputs={"answer": "I'm a helpful AI assistant. How can I help you today?"}, - description="Testing LLM workflow with mocked response", - use_auto_mock=True, # Enable auto-mocking - mock_config=mock_config, - ) - - # Run the test - result = runner.run_test_case(test_case) - - if result.success: - print("✅ Test passed!") - print(f" Input: {test_case.inputs['query']}") - print(f" Output: {result.actual_outputs['answer']}") - print(f" Execution time: {result.execution_time:.2f}s") - else: - print(f"❌ Test failed: {result.error}") - - return result.success - - -def example_test_with_custom_outputs(): - """ - Example: Testing with custom outputs for specific nodes. - - This shows how to provide different mock outputs for specific node IDs, - useful when testing complex workflows with multiple LLM/tool nodes. 
- """ - print("\n=== Example: Custom Node Outputs ===\n") - - runner = TableTestRunner() - - # Configure mock with specific outputs for different nodes - mock_config = MockConfigBuilder().build() - - # Set custom output for a specific LLM node - mock_config.set_node_outputs( - "llm_node", - { - "text": "This is a custom response for the specific LLM node", - "usage": { - "prompt_tokens": 50, - "completion_tokens": 20, - "total_tokens": 70, - }, - "finish_reason": "stop", - }, - ) - - test_case = WorkflowTestCase( - fixture_path="llm-simple", - inputs={"query": "Tell me about custom outputs"}, - expected_outputs={"answer": "This is a custom response for the specific LLM node"}, - description="Testing with custom node outputs", - use_auto_mock=True, - mock_config=mock_config, - ) - - result = runner.run_test_case(test_case) - - if result.success: - print("✅ Test with custom outputs passed!") - print(f" Custom output: {result.actual_outputs['answer']}") - else: - print(f"❌ Test failed: {result.error}") - - return result.success - - -def example_test_http_and_tool_workflow(): - """ - Example: Testing a workflow with HTTP request and tool nodes. - - This demonstrates mocking external HTTP calls and tool executions. 
- """ - print("\n=== Example: HTTP and Tool Workflow ===\n") - - runner = TableTestRunner() - - # Configure mocks for HTTP and Tool nodes - mock_config = MockConfigBuilder().build() - - # Mock HTTP response - mock_config.set_node_outputs( - "http_node", - { - "status_code": 200, - "body": '{"users": [{"id": 1, "name": "Alice"}, {"id": 2, "name": "Bob"}]}', - "headers": {"content-type": "application/json"}, - }, - ) - - # Mock tool response (e.g., JSON parser) - mock_config.set_node_outputs( - "tool_node", - { - "result": {"users": [{"id": 1, "name": "Alice"}, {"id": 2, "name": "Bob"}]}, - }, - ) - - test_case = WorkflowTestCase( - fixture_path="http-tool-workflow", - inputs={"url": "https://api.example.com/users"}, - expected_outputs={ - "status_code": 200, - "parsed_data": {"users": [{"id": 1, "name": "Alice"}, {"id": 2, "name": "Bob"}]}, - }, - description="Testing HTTP and Tool workflow", - use_auto_mock=True, - mock_config=mock_config, - ) - - result = runner.run_test_case(test_case) - - if result.success: - print("✅ HTTP and Tool workflow test passed!") - print(f" HTTP Status: {result.actual_outputs['status_code']}") - print(f" Parsed Data: {result.actual_outputs['parsed_data']}") - else: - print(f"❌ Test failed: {result.error}") - - return result.success - - -def example_test_error_simulation(): - """ - Example: Simulating errors in specific nodes. - - This shows how to test error handling in workflows by simulating - failures in specific nodes. 
- """ - print("\n=== Example: Error Simulation ===\n") - - runner = TableTestRunner() - - # Configure mock to simulate an error - mock_config = MockConfigBuilder().build() - mock_config.set_node_error("llm_node", "API rate limit exceeded") - - test_case = WorkflowTestCase( - fixture_path="llm-simple", - inputs={"query": "This will fail"}, - expected_outputs={}, # We expect failure - description="Testing error handling", - use_auto_mock=True, - mock_config=mock_config, - ) - - result = runner.run_test_case(test_case) - - if not result.success: - print("✅ Error simulation worked as expected!") - print(f" Simulated error: {result.error}") - else: - print("❌ Expected failure but test succeeded") - - return not result.success # Success means we got the expected error - - -def example_test_with_delays(): - """ - Example: Testing with simulated execution delays. - - This demonstrates how to simulate realistic execution times - for performance testing. - """ - print("\n=== Example: Simulated Delays ===\n") - - runner = TableTestRunner() - - # Configure mock with delays - mock_config = ( - MockConfigBuilder() - .with_delays(True) # Enable delay simulation - .with_llm_response("Response after delay") - .build() - ) - - # Add specific delay for the LLM node - from .test_mock_config import NodeMockConfig - - node_config = NodeMockConfig( - node_id="llm_node", - outputs={"text": "Response after delay"}, - delay=0.5, # 500ms delay - ) - mock_config.set_node_config("llm_node", node_config) - - test_case = WorkflowTestCase( - fixture_path="llm-simple", - inputs={"query": "Test with delay"}, - expected_outputs={"answer": "Response after delay"}, - description="Testing with simulated delays", - use_auto_mock=True, - mock_config=mock_config, - ) - - result = runner.run_test_case(test_case) - - if result.success: - print("✅ Delay simulation test passed!") - print(f" Execution time: {result.execution_time:.2f}s") - print(" (Should be >= 0.5s due to simulated delay)") - else: - 
print(f"❌ Test failed: {result.error}") - - return result.success and result.execution_time >= 0.5 - - -def run_all_examples(): - """Run all example tests.""" - print("\n" + "=" * 50) - print("AUTO-MOCK SYSTEM EXAMPLES") - print("=" * 50) - - examples = [ - example_test_llm_workflow, - example_test_with_custom_outputs, - example_test_http_and_tool_workflow, - example_test_error_simulation, - example_test_with_delays, - ] - - results = [] - for example in examples: - try: - results.append(example()) - except Exception as e: - print(f"\n❌ Example failed with exception: {e}") - results.append(False) - - print("\n" + "=" * 50) - print("SUMMARY") - print("=" * 50) - - passed = sum(results) - total = len(results) - print(f"\n✅ Passed: {passed}/{total}") - - if passed == total: - print("\n🎉 All examples passed successfully!") - else: - print(f"\n⚠️ {total - passed} example(s) failed") - - return passed == total - - -if __name__ == "__main__": - import sys - - success = run_all_examples() - sys.exit(0 if success else 1) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py index 93010eea54..88989db856 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py @@ -7,10 +7,11 @@ requiring external services (LLM, Agent, Tool, Knowledge Retrieval, HTTP Request from typing import TYPE_CHECKING, Any +from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter +from graphon.enums import BuiltinNodeTypes, NodeType +from graphon.nodes.base.node import Node + from core.workflow.node_factory import DifyNodeFactory -from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter -from dify_graph.enums import BuiltinNodeTypes, NodeType -from dify_graph.nodes.base.node import Node from .test_mock_nodes import ( MockAgentNode, @@ -28,8 +29,8 @@ from 
.test_mock_nodes import ( ) if TYPE_CHECKING: - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState + from graphon.entities import GraphInitParams + from graphon.runtime import GraphRuntimeState from .test_mock_config import MockConfig @@ -111,7 +112,7 @@ class MockNodeFactory(DifyNodeFactory): mock_config=self.mock_config, http_request_config=self._http_request_config, http_client=self._http_request_http_client, - tool_file_manager_factory=self._http_request_tool_file_manager_factory, + tool_file_manager_factory=self._bound_tool_file_manager_factory, file_manager=self._http_request_file_manager, ) elif node_type in { diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_iteration_simple.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_iteration_simple.py deleted file mode 100644 index 3e4247f33f..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_iteration_simple.py +++ /dev/null @@ -1,199 +0,0 @@ -""" -Simple test to verify MockNodeFactory works with iteration nodes. 
-""" - -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY -from dify_graph.enums import BuiltinNodeTypes -from tests.unit_tests.core.workflow.graph_engine.test_mock_config import MockConfigBuilder -from tests.unit_tests.core.workflow.graph_engine.test_mock_factory import MockNodeFactory - - -def test_mock_factory_registers_iteration_node(): - """Test that MockNodeFactory has iteration node registered.""" - from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState, VariablePool - - # Create a MockNodeFactory instance - graph_init_params = GraphInitParams( - workflow_id="test", - graph_config={"nodes": [], "edges": []}, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "test", - "app_id": "test", - "user_id": "test", - "user_from": UserFrom.ACCOUNT, - "invoke_from": InvokeFrom.SERVICE_API, - } - }, - call_depth=0, - ) - graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(environment_variables=[], conversation_variables=[], user_inputs={}), - start_at=0, - total_tokens=0, - node_run_steps=0, - ) - factory = MockNodeFactory( - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - mock_config=None, - ) - - # Check that iteration node is registered - assert BuiltinNodeTypes.ITERATION in factory._mock_node_types - print("✓ Iteration node is registered in MockNodeFactory") - - # Check that loop node is registered - assert BuiltinNodeTypes.LOOP in factory._mock_node_types - print("✓ Loop node is registered in MockNodeFactory") - - # Check the class types - from tests.unit_tests.core.workflow.graph_engine.test_mock_nodes import MockIterationNode, MockLoopNode - - assert factory._mock_node_types[BuiltinNodeTypes.ITERATION] == MockIterationNode - print("✓ Iteration node maps to MockIterationNode class") - - assert factory._mock_node_types[BuiltinNodeTypes.LOOP] == MockLoopNode - print("✓ Loop 
node maps to MockLoopNode class") - - -def test_mock_iteration_node_preserves_config(): - """Test that MockIterationNode preserves mock configuration.""" - - from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState, VariablePool - from tests.unit_tests.core.workflow.graph_engine.test_mock_nodes import MockIterationNode - - # Create mock config - mock_config = MockConfigBuilder().with_llm_response("Test response").build() - - # Create minimal graph init params - graph_init_params = GraphInitParams( - workflow_id="test", - graph_config={"nodes": [], "edges": []}, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "test", - "app_id": "test", - "user_id": "test", - "user_from": UserFrom.ACCOUNT, - "invoke_from": InvokeFrom.SERVICE_API, - } - }, - call_depth=0, - ) - - # Create minimal runtime state - graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(environment_variables=[], conversation_variables=[], user_inputs={}), - start_at=0, - total_tokens=0, - node_run_steps=0, - ) - - # Create mock iteration node - node_config = { - "id": "iter1", - "data": { - "type": "iteration", - "title": "Test", - "iterator_selector": ["start", "items"], - "output_selector": ["node", "text"], - "start_node_id": "node1", - }, - } - - mock_node = MockIterationNode( - id="iter1", - config=node_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - mock_config=mock_config, - ) - - # Verify the mock config is preserved - assert mock_node.mock_config == mock_config - print("✓ MockIterationNode preserves mock configuration") - - # Check that _create_graph_engine method exists and is overridden - assert hasattr(mock_node, "_create_graph_engine") - assert MockIterationNode._create_graph_engine != MockIterationNode.__bases__[1]._create_graph_engine - print("✓ MockIterationNode overrides _create_graph_engine method") - - 
-def test_mock_loop_node_preserves_config(): - """Test that MockLoopNode preserves mock configuration.""" - - from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState, VariablePool - from tests.unit_tests.core.workflow.graph_engine.test_mock_nodes import MockLoopNode - - # Create mock config - mock_config = MockConfigBuilder().with_http_response({"status": 200}).build() - - # Create minimal graph init params - graph_init_params = GraphInitParams( - workflow_id="test", - graph_config={"nodes": [], "edges": []}, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "test", - "app_id": "test", - "user_id": "test", - "user_from": UserFrom.ACCOUNT, - "invoke_from": InvokeFrom.SERVICE_API, - } - }, - call_depth=0, - ) - - # Create minimal runtime state - graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(environment_variables=[], conversation_variables=[], user_inputs={}), - start_at=0, - total_tokens=0, - node_run_steps=0, - ) - - # Create mock loop node - node_config = { - "id": "loop1", - "data": { - "type": "loop", - "title": "Test", - "loop_count": 3, - "start_node_id": "node1", - "loop_variables": [], - "outputs": {}, - "break_conditions": [], - "logical_operator": "and", - }, - } - - mock_node = MockLoopNode( - id="loop1", - config=node_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - mock_config=mock_config, - ) - - # Verify the mock config is preserved - assert mock_node.mock_config == mock_config - print("✓ MockLoopNode preserves mock configuration") - - # Check that _create_graph_engine method exists and is overridden - assert hasattr(mock_node, "_create_graph_engine") - assert MockLoopNode._create_graph_engine != MockLoopNode.__bases__[1]._create_graph_engine - print("✓ MockLoopNode overrides _create_graph_engine method") - - -if __name__ == "__main__": - 
test_mock_factory_registers_iteration_node() - test_mock_iteration_node_preserves_config() - test_mock_loop_node_preserves_config() - print("\n✅ All tests passed! MockNodeFactory now supports iteration and loop nodes.") diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py index 454263bef9..8b7fbd1b30 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes.py @@ -10,30 +10,31 @@ from collections.abc import Generator, Mapping from typing import TYPE_CHECKING, Any, Optional from unittest.mock import MagicMock +from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent +from graphon.nodes.code import CodeNode +from graphon.nodes.document_extractor import DocumentExtractorNode +from graphon.nodes.http_request import HttpRequestNode +from graphon.nodes.llm import LLMNode +from graphon.nodes.llm.file_saver import LLMFileSaver +from graphon.nodes.llm.protocols import CredentialsProvider, ModelFactory +from graphon.nodes.llm.runtime_protocols import PromptMessageSerializerProtocol +from graphon.nodes.parameter_extractor import ParameterExtractorNode +from graphon.nodes.protocols import FileReferenceFactoryProtocol, HttpClientProtocol, ToolFileManagerProtocol +from graphon.nodes.question_classifier import QuestionClassifierNode +from graphon.nodes.template_transform import TemplateTransformNode +from graphon.nodes.tool import ToolNode +from graphon.template_rendering import Jinja2TemplateRenderer, TemplateRenderError + from core.model_manager import ModelInstance +from core.workflow.node_runtime import DifyToolNodeRuntime from core.workflow.nodes.agent import AgentNode from 
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode -from dify_graph.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus -from dify_graph.model_runtime.entities.llm_entities import LLMUsage -from dify_graph.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent -from dify_graph.nodes.code import CodeNode -from dify_graph.nodes.document_extractor import DocumentExtractorNode -from dify_graph.nodes.http_request import HttpRequestNode -from dify_graph.nodes.llm import LLMNode -from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer -from dify_graph.nodes.parameter_extractor import ParameterExtractorNode -from dify_graph.nodes.protocols import HttpClientProtocol, ToolFileManagerProtocol -from dify_graph.nodes.question_classifier import QuestionClassifierNode -from dify_graph.nodes.template_transform import TemplateTransformNode -from dify_graph.nodes.template_transform.template_renderer import ( - Jinja2TemplateRenderer, - TemplateRenderError, -) -from dify_graph.nodes.tool import ToolNode if TYPE_CHECKING: - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState + from graphon.entities import GraphInitParams + from graphon.runtime import GraphRuntimeState from .test_mock_config import MockConfig @@ -66,20 +67,26 @@ class MockNodeMixin: kwargs.setdefault("credentials_provider", MagicMock(spec=CredentialsProvider)) kwargs.setdefault("model_factory", MagicMock(spec=ModelFactory)) kwargs.setdefault("model_instance", MagicMock(spec=ModelInstance)) + kwargs.setdefault("prompt_message_serializer", MagicMock(spec=PromptMessageSerializerProtocol)) # LLM-like nodes now require an http_client; provide a mock by default for tests. 
kwargs.setdefault("http_client", MagicMock(spec=HttpClientProtocol)) - if isinstance(self, (LLMNode, QuestionClassifierNode)): - kwargs.setdefault("template_renderer", MagicMock(spec=TemplateRenderer)) + + if isinstance(self, (LLMNode, QuestionClassifierNode)): + kwargs.setdefault("llm_file_saver", MagicMock(spec=LLMFileSaver)) + + if isinstance(self, HttpRequestNode): + kwargs.setdefault("file_reference_factory", MagicMock(spec=FileReferenceFactoryProtocol)) # Ensure TemplateTransformNode receives a renderer now required by constructor if isinstance(self, TemplateTransformNode): - kwargs.setdefault("template_renderer", _TestJinja2Renderer()) + kwargs.setdefault("jinja2_template_renderer", _TestJinja2Renderer()) # Provide default tool_file_manager_factory for ToolNode subclasses - from dify_graph.nodes.tool import ToolNode as _ToolNode # local import to avoid cycles + from graphon.nodes.tool import ToolNode as _ToolNode # local import to avoid cycles if isinstance(self, _ToolNode): kwargs.setdefault("tool_file_manager_factory", MagicMock(spec=ToolFileManagerProtocol)) + kwargs.setdefault("runtime", DifyToolNodeRuntime(graph_init_params.run_context)) if isinstance(self, AgentNode): presentation_provider = MagicMock() @@ -596,8 +603,8 @@ class MockDocumentExtractorNode(MockNodeMixin, DocumentExtractorNode): ) -from dify_graph.nodes.iteration import IterationNode -from dify_graph.nodes.loop import LoopNode +from graphon.nodes.iteration import IterationNode +from graphon.nodes.loop import LoopNode class MockIterationNode(MockNodeMixin, IterationNode): @@ -611,11 +618,11 @@ class MockIterationNode(MockNodeMixin, IterationNode): def _create_graph_engine(self, index: int, item: Any): """Create a graph engine with MockNodeFactory instead of DifyNodeFactory.""" # Import dependencies - from dify_graph.entities import GraphInitParams - from dify_graph.graph import Graph - from dify_graph.graph_engine import GraphEngine, GraphEngineConfig - from 
dify_graph.graph_engine.command_channels import InMemoryChannel - from dify_graph.runtime import GraphRuntimeState + from graphon.entities import GraphInitParams + from graphon.graph import Graph + from graphon.graph_engine import GraphEngine, GraphEngineConfig + from graphon.graph_engine.command_channels import InMemoryChannel + from graphon.runtime import GraphRuntimeState # Import our MockNodeFactory instead of DifyNodeFactory from .test_mock_factory import MockNodeFactory @@ -656,7 +663,7 @@ class MockIterationNode(MockNodeMixin, IterationNode): ) if not iteration_graph: - from dify_graph.nodes.iteration.exc import IterationGraphNotFoundError + from graphon.nodes.iteration.exc import IterationGraphNotFoundError raise IterationGraphNotFoundError("iteration graph not found") @@ -683,11 +690,11 @@ class MockLoopNode(MockNodeMixin, LoopNode): def _create_graph_engine(self, start_at, root_node_id: str): """Create a graph engine with MockNodeFactory instead of DifyNodeFactory.""" # Import dependencies - from dify_graph.entities import GraphInitParams - from dify_graph.graph import Graph - from dify_graph.graph_engine import GraphEngine, GraphEngineConfig - from dify_graph.graph_engine.command_channels import InMemoryChannel - from dify_graph.runtime import GraphRuntimeState + from graphon.entities import GraphInitParams + from graphon.graph import Graph + from graphon.graph_engine import GraphEngine, GraphEngineConfig + from graphon.graph_engine.command_channels import InMemoryChannel + from graphon.runtime import GraphRuntimeState # Import our MockNodeFactory instead of DifyNodeFactory from .test_mock_factory import MockNodeFactory diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes_template_code.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes_template_code.py deleted file mode 100644 index a8398e8f79..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_nodes_template_code.py +++ /dev/null @@ 
-1,670 +0,0 @@ -""" -Test cases for Mock Template Transform and Code nodes. - -This module tests the functionality of MockTemplateTransformNode and MockCodeNode -to ensure they work correctly with the TableTestRunner. -""" - -from configs import dify_config -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from dify_graph.nodes.code.limits import CodeNodeLimits -from tests.unit_tests.core.workflow.graph_engine.test_mock_config import MockConfig, MockConfigBuilder, NodeMockConfig -from tests.unit_tests.core.workflow.graph_engine.test_mock_factory import MockNodeFactory -from tests.unit_tests.core.workflow.graph_engine.test_mock_nodes import MockCodeNode, MockTemplateTransformNode - -DEFAULT_CODE_LIMITS = CodeNodeLimits( - max_string_length=dify_config.CODE_MAX_STRING_LENGTH, - max_number=dify_config.CODE_MAX_NUMBER, - min_number=dify_config.CODE_MIN_NUMBER, - max_precision=dify_config.CODE_MAX_PRECISION, - max_depth=dify_config.CODE_MAX_DEPTH, - max_number_array_length=dify_config.CODE_MAX_NUMBER_ARRAY_LENGTH, - max_string_array_length=dify_config.CODE_MAX_STRING_ARRAY_LENGTH, - max_object_array_length=dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH, -) - - -class _NoopCodeExecutor: - def execute(self, *, language: object, code: str, inputs: dict[str, object]) -> dict[str, object]: - _ = (language, code, inputs) - return {} - - def is_execution_error(self, error: Exception) -> bool: - _ = error - return False - - -class TestMockTemplateTransformNode: - """Test cases for MockTemplateTransformNode.""" - - def test_mock_template_transform_node_default_output(self): - """Test that MockTemplateTransformNode processes templates with Jinja2.""" - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState, VariablePool - - # Create test parameters - graph_init_params = GraphInitParams( - workflow_id="test_workflow", - graph_config={}, - 
run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "test_tenant", - "app_id": "test_app", - "user_id": "test_user", - "user_from": "account", - "invoke_from": "debugger", - } - }, - call_depth=0, - ) - - variable_pool = VariablePool( - system_variables={}, - user_inputs={}, - ) - - graph_runtime_state = GraphRuntimeState( - variable_pool=variable_pool, - start_at=0, - ) - - # Create mock config - mock_config = MockConfig() - - # Create node config - node_config = { - "id": "template_node_1", - "data": { - "type": "template-transform", - "title": "Test Template Transform", - "variables": [], - "template": "Hello {{ name }}", - }, - } - - # Create mock node - mock_node = MockTemplateTransformNode( - id="template_node_1", - config=node_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - mock_config=mock_config, - ) - - # Run the node - result = mock_node._run() - - # Verify results - assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert "output" in result.outputs - # The template "Hello {{ name }}" with no name variable renders as "Hello " - assert result.outputs["output"] == "Hello " - - def test_mock_template_transform_node_custom_output(self): - """Test that MockTemplateTransformNode returns custom configured output.""" - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState, VariablePool - - # Create test parameters - graph_init_params = GraphInitParams( - workflow_id="test_workflow", - graph_config={}, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "test_tenant", - "app_id": "test_app", - "user_id": "test_user", - "user_from": "account", - "invoke_from": "debugger", - } - }, - call_depth=0, - ) - - variable_pool = VariablePool( - system_variables={}, - user_inputs={}, - ) - - graph_runtime_state = GraphRuntimeState( - variable_pool=variable_pool, - start_at=0, - ) - - # Create mock config with custom output - mock_config = ( - 
MockConfigBuilder().with_node_output("template_node_1", {"output": "Custom template output"}).build() - ) - - # Create node config - node_config = { - "id": "template_node_1", - "data": { - "type": "template-transform", - "title": "Test Template Transform", - "variables": [], - "template": "Hello {{ name }}", - }, - } - - # Create mock node - mock_node = MockTemplateTransformNode( - id="template_node_1", - config=node_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - mock_config=mock_config, - ) - - # Run the node - result = mock_node._run() - - # Verify results - assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert "output" in result.outputs - assert result.outputs["output"] == "Custom template output" - - def test_mock_template_transform_node_error_simulation(self): - """Test that MockTemplateTransformNode can simulate errors.""" - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState, VariablePool - - # Create test parameters - graph_init_params = GraphInitParams( - workflow_id="test_workflow", - graph_config={}, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "test_tenant", - "app_id": "test_app", - "user_id": "test_user", - "user_from": "account", - "invoke_from": "debugger", - } - }, - call_depth=0, - ) - - variable_pool = VariablePool( - system_variables={}, - user_inputs={}, - ) - - graph_runtime_state = GraphRuntimeState( - variable_pool=variable_pool, - start_at=0, - ) - - # Create mock config with error - mock_config = MockConfigBuilder().with_node_error("template_node_1", "Simulated template error").build() - - # Create node config - node_config = { - "id": "template_node_1", - "data": { - "type": "template-transform", - "title": "Test Template Transform", - "variables": [], - "template": "Hello {{ name }}", - }, - } - - # Create mock node - mock_node = MockTemplateTransformNode( - id="template_node_1", - config=node_config, - 
graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - mock_config=mock_config, - ) - - # Run the node - result = mock_node._run() - - # Verify results - assert result.status == WorkflowNodeExecutionStatus.FAILED - assert result.error == "Simulated template error" - - def test_mock_template_transform_node_with_variables(self): - """Test that MockTemplateTransformNode processes templates with variables.""" - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState, VariablePool - from dify_graph.variables import StringVariable - - # Create test parameters - graph_init_params = GraphInitParams( - workflow_id="test_workflow", - graph_config={}, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "test_tenant", - "app_id": "test_app", - "user_id": "test_user", - "user_from": "account", - "invoke_from": "debugger", - } - }, - call_depth=0, - ) - - variable_pool = VariablePool( - system_variables={}, - user_inputs={}, - ) - - # Add a variable to the pool - variable_pool.add(["test", "name"], StringVariable(name="name", value="World", selector=["test", "name"])) - - graph_runtime_state = GraphRuntimeState( - variable_pool=variable_pool, - start_at=0, - ) - - # Create mock config - mock_config = MockConfig() - - # Create node config with a variable - node_config = { - "id": "template_node_1", - "data": { - "type": "template-transform", - "title": "Test Template Transform", - "variables": [{"variable": "name", "value_selector": ["test", "name"]}], - "template": "Hello {{ name }}!", - }, - } - - # Create mock node - mock_node = MockTemplateTransformNode( - id="template_node_1", - config=node_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - mock_config=mock_config, - ) - - # Run the node - result = mock_node._run() - - # Verify results - assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert "output" in result.outputs - assert 
result.outputs["output"] == "Hello World!" - - -class TestMockCodeNode: - """Test cases for MockCodeNode.""" - - def test_mock_code_node_default_output(self): - """Test that MockCodeNode returns default output.""" - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState, VariablePool - - # Create test parameters - graph_init_params = GraphInitParams( - workflow_id="test_workflow", - graph_config={}, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "test_tenant", - "app_id": "test_app", - "user_id": "test_user", - "user_from": "account", - "invoke_from": "debugger", - } - }, - call_depth=0, - ) - - variable_pool = VariablePool( - system_variables={}, - user_inputs={}, - ) - - graph_runtime_state = GraphRuntimeState( - variable_pool=variable_pool, - start_at=0, - ) - - # Create mock config - mock_config = MockConfig() - - # Create node config - node_config = { - "id": "code_node_1", - "data": { - "type": "code", - "title": "Test Code", - "variables": [], - "code_language": "python3", - "code": "result = 'test'", - "outputs": {}, # Empty outputs for default case - }, - } - - # Create mock node - mock_node = MockCodeNode( - id="code_node_1", - config=node_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - mock_config=mock_config, - code_executor=_NoopCodeExecutor(), - code_limits=DEFAULT_CODE_LIMITS, - ) - - # Run the node - result = mock_node._run() - - # Verify results - assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert "result" in result.outputs - assert result.outputs["result"] == "mocked code execution result" - - def test_mock_code_node_with_output_schema(self): - """Test that MockCodeNode generates outputs based on schema.""" - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState, VariablePool - - # Create test parameters - graph_init_params = GraphInitParams( - workflow_id="test_workflow", - 
graph_config={}, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "test_tenant", - "app_id": "test_app", - "user_id": "test_user", - "user_from": "account", - "invoke_from": "debugger", - } - }, - call_depth=0, - ) - - variable_pool = VariablePool( - system_variables={}, - user_inputs={}, - ) - - graph_runtime_state = GraphRuntimeState( - variable_pool=variable_pool, - start_at=0, - ) - - # Create mock config - mock_config = MockConfig() - - # Create node config with output schema - node_config = { - "id": "code_node_1", - "data": { - "type": "code", - "title": "Test Code", - "variables": [], - "code_language": "python3", - "code": "name = 'test'\ncount = 42\nitems = ['a', 'b']", - "outputs": { - "name": {"type": "string"}, - "count": {"type": "number"}, - "items": {"type": "array[string]"}, - }, - }, - } - - # Create mock node - mock_node = MockCodeNode( - id="code_node_1", - config=node_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - mock_config=mock_config, - code_executor=_NoopCodeExecutor(), - code_limits=DEFAULT_CODE_LIMITS, - ) - - # Run the node - result = mock_node._run() - - # Verify results - assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert "name" in result.outputs - assert result.outputs["name"] == "mocked_name" - assert "count" in result.outputs - assert result.outputs["count"] == 42 - assert "items" in result.outputs - assert result.outputs["items"] == ["item1", "item2"] - - def test_mock_code_node_custom_output(self): - """Test that MockCodeNode returns custom configured output.""" - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState, VariablePool - - # Create test parameters - graph_init_params = GraphInitParams( - workflow_id="test_workflow", - graph_config={}, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "test_tenant", - "app_id": "test_app", - "user_id": "test_user", - "user_from": "account", - "invoke_from": 
"debugger", - } - }, - call_depth=0, - ) - - variable_pool = VariablePool( - system_variables={}, - user_inputs={}, - ) - - graph_runtime_state = GraphRuntimeState( - variable_pool=variable_pool, - start_at=0, - ) - - # Create mock config with custom output - mock_config = ( - MockConfigBuilder() - .with_node_output("code_node_1", {"result": "Custom code result", "status": "success"}) - .build() - ) - - # Create node config - node_config = { - "id": "code_node_1", - "data": { - "type": "code", - "title": "Test Code", - "variables": [], - "code_language": "python3", - "code": "result = 'test'", - "outputs": {}, # Empty outputs for default case - }, - } - - # Create mock node - mock_node = MockCodeNode( - id="code_node_1", - config=node_config, - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - mock_config=mock_config, - code_executor=_NoopCodeExecutor(), - code_limits=DEFAULT_CODE_LIMITS, - ) - - # Run the node - result = mock_node._run() - - # Verify results - assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED - assert "result" in result.outputs - assert result.outputs["result"] == "Custom code result" - assert "status" in result.outputs - assert result.outputs["status"] == "success" - - -class TestMockNodeFactory: - """Test cases for MockNodeFactory with new node types.""" - - def test_code_and_template_nodes_mocked_by_default(self): - """Test that CODE and TEMPLATE_TRANSFORM nodes are mocked by default (they require SSRF proxy).""" - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState, VariablePool - - # Create test parameters - graph_init_params = GraphInitParams( - workflow_id="test_workflow", - graph_config={}, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "test_tenant", - "app_id": "test_app", - "user_id": "test_user", - "user_from": "account", - "invoke_from": "debugger", - } - }, - call_depth=0, - ) - - variable_pool = VariablePool( - system_variables={}, - 
user_inputs={}, - ) - - graph_runtime_state = GraphRuntimeState( - variable_pool=variable_pool, - start_at=0, - ) - - # Create factory - factory = MockNodeFactory( - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - - # Verify that CODE and TEMPLATE_TRANSFORM ARE mocked by default (they require SSRF proxy) - assert factory.should_mock_node(BuiltinNodeTypes.CODE) - assert factory.should_mock_node(BuiltinNodeTypes.TEMPLATE_TRANSFORM) - - # Verify that other third-party service nodes ARE also mocked by default - assert factory.should_mock_node(BuiltinNodeTypes.LLM) - assert factory.should_mock_node(BuiltinNodeTypes.AGENT) - - def test_factory_creates_mock_template_transform_node(self): - """Test that MockNodeFactory creates MockTemplateTransformNode for template-transform type.""" - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState, VariablePool - - # Create test parameters - graph_init_params = GraphInitParams( - workflow_id="test_workflow", - graph_config={}, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "test_tenant", - "app_id": "test_app", - "user_id": "test_user", - "user_from": "account", - "invoke_from": "debugger", - } - }, - call_depth=0, - ) - - variable_pool = VariablePool( - system_variables={}, - user_inputs={}, - ) - - graph_runtime_state = GraphRuntimeState( - variable_pool=variable_pool, - start_at=0, - ) - - # Create factory - factory = MockNodeFactory( - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - - # Create node config - node_config = { - "id": "template_node_1", - "data": { - "type": "template-transform", - "title": "Test Template", - "variables": [], - "template": "Hello {{ name }}", - }, - } - - # Create node through factory - node = factory.create_node(node_config) - - # Verify the correct mock type was created - assert isinstance(node, MockTemplateTransformNode) - assert 
factory.should_mock_node(BuiltinNodeTypes.TEMPLATE_TRANSFORM) - - def test_factory_creates_mock_code_node(self): - """Test that MockNodeFactory creates MockCodeNode for code type.""" - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState, VariablePool - - # Create test parameters - graph_init_params = GraphInitParams( - workflow_id="test_workflow", - graph_config={}, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "test_tenant", - "app_id": "test_app", - "user_id": "test_user", - "user_from": "account", - "invoke_from": "debugger", - } - }, - call_depth=0, - ) - - variable_pool = VariablePool( - system_variables={}, - user_inputs={}, - ) - - graph_runtime_state = GraphRuntimeState( - variable_pool=variable_pool, - start_at=0, - ) - - # Create factory - factory = MockNodeFactory( - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - ) - - # Create node config - node_config = { - "id": "code_node_1", - "data": { - "type": "code", - "title": "Test Code", - "variables": [], - "code_language": "python3", - "code": "result = 42", - "outputs": {}, # Required field for CodeNodeData - }, - } - - # Create node through factory - node = factory.create_node(node_config) - - # Verify the correct mock type was created - assert isinstance(node, MockCodeNode) - assert factory.should_mock_node(BuiltinNodeTypes.CODE) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_simple.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_simple.py deleted file mode 100644 index 5b35b3310a..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_simple.py +++ /dev/null @@ -1,231 +0,0 @@ -""" -Simple test to validate the auto-mock system without external dependencies. 
-""" - -import sys - -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY -from dify_graph.enums import BuiltinNodeTypes -from tests.unit_tests.core.workflow.graph_engine.test_mock_config import MockConfig, MockConfigBuilder, NodeMockConfig -from tests.unit_tests.core.workflow.graph_engine.test_mock_factory import MockNodeFactory - - -def test_mock_config_builder(): - """Test the MockConfigBuilder fluent interface.""" - print("Testing MockConfigBuilder...") - - config = ( - MockConfigBuilder() - .with_llm_response("LLM response") - .with_agent_response("Agent response") - .with_tool_response({"tool": "output"}) - .with_retrieval_response("Retrieval content") - .with_http_response({"status_code": 201, "body": "created"}) - .with_node_output("node1", {"output": "value"}) - .with_node_error("node2", "error message") - .with_delays(True) - .build() - ) - - assert config.default_llm_response == "LLM response" - assert config.default_agent_response == "Agent response" - assert config.default_tool_response == {"tool": "output"} - assert config.default_retrieval_response == "Retrieval content" - assert config.default_http_response == {"status_code": 201, "body": "created"} - assert config.simulate_delays is True - - node1_config = config.get_node_config("node1") - assert node1_config is not None - assert node1_config.outputs == {"output": "value"} - - node2_config = config.get_node_config("node2") - assert node2_config is not None - assert node2_config.error == "error message" - - print("✓ MockConfigBuilder test passed") - - -def test_mock_config_operations(): - """Test MockConfig operations.""" - print("Testing MockConfig operations...") - - config = MockConfig() - - # Test setting node outputs - config.set_node_outputs("test_node", {"result": "test_value"}) - node_config = config.get_node_config("test_node") - assert node_config is not None - assert node_config.outputs == {"result": "test_value"} - - # Test setting node error - 
config.set_node_error("error_node", "Test error") - error_config = config.get_node_config("error_node") - assert error_config is not None - assert error_config.error == "Test error" - - # Test default configs by node type - config.set_default_config(BuiltinNodeTypes.LLM, {"temperature": 0.7}) - llm_config = config.get_default_config(BuiltinNodeTypes.LLM) - assert llm_config == {"temperature": 0.7} - - print("✓ MockConfig operations test passed") - - -def test_node_mock_config(): - """Test NodeMockConfig.""" - print("Testing NodeMockConfig...") - - # Test with custom handler - def custom_handler(node): - return {"custom": "output"} - - node_config = NodeMockConfig( - node_id="test_node", outputs={"text": "test"}, error=None, delay=0.5, custom_handler=custom_handler - ) - - assert node_config.node_id == "test_node" - assert node_config.outputs == {"text": "test"} - assert node_config.delay == 0.5 - assert node_config.custom_handler is not None - - # Test custom handler - result = node_config.custom_handler(None) - assert result == {"custom": "output"} - - print("✓ NodeMockConfig test passed") - - -def test_mock_factory_detection(): - """Test MockNodeFactory node type detection.""" - from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState, VariablePool - - print("Testing MockNodeFactory detection...") - - graph_init_params = GraphInitParams( - workflow_id="test", - graph_config={}, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "test", - "app_id": "test", - "user_id": "test", - "user_from": UserFrom.ACCOUNT, - "invoke_from": InvokeFrom.SERVICE_API, - } - }, - call_depth=0, - ) - graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(environment_variables=[], conversation_variables=[], user_inputs={}), - start_at=0, - total_tokens=0, - node_run_steps=0, - ) - factory = MockNodeFactory( - graph_init_params=graph_init_params, - 
graph_runtime_state=graph_runtime_state, - mock_config=None, - ) - - # Test that third-party service nodes are identified for mocking - assert factory.should_mock_node(BuiltinNodeTypes.LLM) - assert factory.should_mock_node(BuiltinNodeTypes.AGENT) - assert factory.should_mock_node(BuiltinNodeTypes.TOOL) - assert factory.should_mock_node(BuiltinNodeTypes.KNOWLEDGE_RETRIEVAL) - assert factory.should_mock_node(BuiltinNodeTypes.HTTP_REQUEST) - assert factory.should_mock_node(BuiltinNodeTypes.PARAMETER_EXTRACTOR) - assert factory.should_mock_node(BuiltinNodeTypes.DOCUMENT_EXTRACTOR) - - # Test that CODE and TEMPLATE_TRANSFORM are mocked (they require SSRF proxy) - assert factory.should_mock_node(BuiltinNodeTypes.CODE) - assert factory.should_mock_node(BuiltinNodeTypes.TEMPLATE_TRANSFORM) - - # Test that non-service nodes are not mocked - assert not factory.should_mock_node(BuiltinNodeTypes.START) - assert not factory.should_mock_node(BuiltinNodeTypes.END) - assert not factory.should_mock_node(BuiltinNodeTypes.IF_ELSE) - assert not factory.should_mock_node(BuiltinNodeTypes.VARIABLE_AGGREGATOR) - - print("✓ MockNodeFactory detection test passed") - - -def test_mock_factory_registration(): - """Test registering and unregistering mock node types.""" - from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom - from dify_graph.entities import GraphInitParams - from dify_graph.runtime import GraphRuntimeState, VariablePool - - print("Testing MockNodeFactory registration...") - - graph_init_params = GraphInitParams( - workflow_id="test", - graph_config={}, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "test", - "app_id": "test", - "user_id": "test", - "user_from": UserFrom.ACCOUNT, - "invoke_from": InvokeFrom.SERVICE_API, - } - }, - call_depth=0, - ) - graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(environment_variables=[], conversation_variables=[], user_inputs={}), - start_at=0, - total_tokens=0, - node_run_steps=0, - ) - 
factory = MockNodeFactory( - graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, - mock_config=None, - ) - - # TEMPLATE_TRANSFORM is mocked by default (requires SSRF proxy) - assert factory.should_mock_node(BuiltinNodeTypes.TEMPLATE_TRANSFORM) - - # Unregister mock - factory.unregister_mock_node_type(BuiltinNodeTypes.TEMPLATE_TRANSFORM) - assert not factory.should_mock_node(BuiltinNodeTypes.TEMPLATE_TRANSFORM) - - # Register custom mock (using a dummy class for testing) - class DummyMockNode: - pass - - factory.register_mock_node_type(BuiltinNodeTypes.TEMPLATE_TRANSFORM, DummyMockNode) - assert factory.should_mock_node(BuiltinNodeTypes.TEMPLATE_TRANSFORM) - - print("✓ MockNodeFactory registration test passed") - - -def run_all_tests(): - """Run all tests.""" - print("\n=== Running Auto-Mock System Tests ===\n") - - try: - test_mock_config_builder() - test_mock_config_operations() - test_node_mock_config() - test_mock_factory_detection() - test_mock_factory_registration() - - print("\n=== All tests passed! 
✅ ===\n") - return True - except AssertionError as e: - print(f"\n❌ Test failed: {e}") - return False - except Exception as e: - print(f"\n❌ Unexpected error: {e}") - import traceback - - traceback.print_exc() - return False - - -if __name__ == "__main__": - success = run_all_tests() - sys.exit(0 if success else 1) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py index e681b39cc7..8311a1e847 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_join_resume.py @@ -4,32 +4,33 @@ from dataclasses import dataclass from datetime import datetime, timedelta from typing import Any, Protocol -from dify_graph.entities.workflow_start_reason import WorkflowStartReason -from dify_graph.graph import Graph -from dify_graph.graph_engine.command_channels.in_memory_channel import InMemoryChannel -from dify_graph.graph_engine.config import GraphEngineConfig -from dify_graph.graph_engine.graph_engine import GraphEngine -from dify_graph.graph_events import ( +from graphon.entities import WorkflowStartReason +from graphon.graph import Graph +from graphon.graph_engine import GraphEngine, GraphEngineConfig +from graphon.graph_engine.command_channels import InMemoryChannel +from graphon.graph_events import ( GraphRunPausedEvent, GraphRunStartedEvent, GraphRunSucceededEvent, NodeRunSucceededEvent, ) -from dify_graph.nodes.base.entities import OutputVariableEntity -from dify_graph.nodes.end.end_node import EndNode -from dify_graph.nodes.end.entities import EndNodeData -from dify_graph.nodes.human_input.entities import HumanInputNodeData, UserAction -from dify_graph.nodes.human_input.enums import HumanInputFormStatus -from dify_graph.nodes.human_input.human_input_node import HumanInputNode -from 
dify_graph.nodes.start.entities import StartNodeData -from dify_graph.nodes.start.start_node import StartNode -from dify_graph.repositories.human_input_form_repository import ( +from graphon.nodes.base.entities import OutputVariableEntity +from graphon.nodes.end.end_node import EndNode +from graphon.nodes.end.entities import EndNodeData +from graphon.nodes.human_input.entities import HumanInputNodeData, UserAction +from graphon.nodes.human_input.enums import HumanInputFormStatus +from graphon.nodes.human_input.human_input_node import HumanInputNode +from graphon.nodes.start.entities import StartNodeData +from graphon.nodes.start.start_node import StartNode +from graphon.runtime import GraphRuntimeState, VariablePool + +from core.repositories.human_input_repository import ( FormCreateParams, HumanInputFormEntity, HumanInputFormRepository, ) -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from core.workflow.node_runtime import DifyHumanInputNodeRuntime +from core.workflow.system_variables import build_system_variables from libs.datetime_utils import naive_utc_now from tests.workflow_test_utils import build_test_graph_init_params @@ -67,7 +68,7 @@ class StaticForm(HumanInputFormEntity): return self.form_id @property - def web_app_token(self) -> str | None: + def submission_token(self) -> str | None: return "token" @property @@ -103,7 +104,7 @@ class StaticRepo(HumanInputFormRepository): def __init__(self, forms_by_node_id: Mapping[str, HumanInputFormEntity]) -> None: self._forms_by_node_id = dict(forms_by_node_id) - def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None: + def get_form(self, node_id: str) -> HumanInputFormEntity | None: return self._forms_by_node_id.get(node_id) def create_form(self, params: FormCreateParams) -> HumanInputFormEntity: @@ -112,7 +113,7 @@ class StaticRepo(HumanInputFormRepository): def _build_runtime_state() -> GraphRuntimeState: 
variable_pool = VariablePool( - system_variables=SystemVariable( + system_variables=build_system_variables( user_id="user", app_id="app", workflow_id="workflow", @@ -159,6 +160,7 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor graph_init_params=graph_init_params, graph_runtime_state=runtime_state, form_repository=repo, + runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context), ) human_b_config = {"id": "human_b", "data": human_data.model_dump()} @@ -168,6 +170,7 @@ def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepositor graph_init_params=graph_init_params, graph_runtime_state=runtime_state, form_repository=repo, + runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context), ) end_data = EndNodeData( diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_pause_missing_finish.py b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_pause_missing_finish.py deleted file mode 100644 index 60167c0441..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_human_input_pause_missing_finish.py +++ /dev/null @@ -1,333 +0,0 @@ -import time -from collections.abc import Mapping -from dataclasses import dataclass -from datetime import datetime, timedelta -from typing import Any - -from dify_graph.entities.workflow_start_reason import WorkflowStartReason -from dify_graph.graph import Graph -from dify_graph.graph_engine.command_channels.in_memory_channel import InMemoryChannel -from dify_graph.graph_engine.config import GraphEngineConfig -from dify_graph.graph_engine.graph_engine import GraphEngine -from dify_graph.graph_events import ( - GraphRunPausedEvent, - GraphRunStartedEvent, - NodeRunPauseRequestedEvent, - NodeRunStartedEvent, - NodeRunSucceededEvent, -) -from dify_graph.model_runtime.entities.llm_entities import LLMMode -from dify_graph.model_runtime.entities.message_entities import PromptMessageRole -from 
dify_graph.nodes.human_input.entities import HumanInputNodeData, UserAction -from dify_graph.nodes.human_input.enums import HumanInputFormStatus -from dify_graph.nodes.human_input.human_input_node import HumanInputNode -from dify_graph.nodes.llm.entities import ( - ContextConfig, - LLMNodeChatModelMessage, - LLMNodeData, - ModelConfig, - VisionConfig, -) -from dify_graph.nodes.start.entities import StartNodeData -from dify_graph.nodes.start.start_node import StartNode -from dify_graph.repositories.human_input_form_repository import ( - FormCreateParams, - HumanInputFormEntity, - HumanInputFormRepository, -) -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable -from libs.datetime_utils import naive_utc_now -from tests.workflow_test_utils import build_test_graph_init_params - -from .test_mock_config import MockConfig, NodeMockConfig -from .test_mock_nodes import MockLLMNode - - -@dataclass -class StaticForm(HumanInputFormEntity): - form_id: str - rendered: str - is_submitted: bool - action_id: str | None = None - data: Mapping[str, Any] | None = None - status_value: HumanInputFormStatus = HumanInputFormStatus.WAITING - expiration: datetime = naive_utc_now() + timedelta(days=1) - - @property - def id(self) -> str: - return self.form_id - - @property - def web_app_token(self) -> str | None: - return "token" - - @property - def recipients(self) -> list: - return [] - - @property - def rendered_content(self) -> str: - return self.rendered - - @property - def selected_action_id(self) -> str | None: - return self.action_id - - @property - def submitted_data(self) -> Mapping[str, Any] | None: - return self.data - - @property - def submitted(self) -> bool: - return self.is_submitted - - @property - def status(self) -> HumanInputFormStatus: - return self.status_value - - @property - def expiration_time(self) -> datetime: - return self.expiration - - -class StaticRepo(HumanInputFormRepository): - def 
__init__(self, forms_by_node_id: Mapping[str, HumanInputFormEntity]) -> None: - self._forms_by_node_id = dict(forms_by_node_id) - - def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None: - return self._forms_by_node_id.get(node_id) - - def create_form(self, params: FormCreateParams) -> HumanInputFormEntity: - raise AssertionError("create_form should not be called in resume scenario") - - -class DelayedHumanInputNode(HumanInputNode): - def __init__(self, delay_seconds: float, **kwargs: Any) -> None: - super().__init__(**kwargs) - self._delay_seconds = delay_seconds - - def _run(self): - if self._delay_seconds > 0: - time.sleep(self._delay_seconds) - yield from super()._run() - - -def _build_runtime_state() -> GraphRuntimeState: - variable_pool = VariablePool( - system_variables=SystemVariable( - user_id="user", - app_id="app", - workflow_id="workflow", - workflow_execution_id="exec-1", - ), - user_inputs={}, - conversation_variables=[], - ) - return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - - -def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepository, mock_config: MockConfig) -> Graph: - graph_config: dict[str, object] = {"nodes": [], "edges": []} - graph_init_params = build_test_graph_init_params( - workflow_id="workflow", - graph_config=graph_config, - tenant_id="tenant", - app_id="app", - user_id="user", - user_from="account", - invoke_from="debugger", - call_depth=0, - ) - - start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()} - start_node = StartNode( - id=start_config["id"], - config=start_config, - graph_init_params=graph_init_params, - graph_runtime_state=runtime_state, - ) - - human_data = HumanInputNodeData( - title="Human Input", - form_content="Human input required", - inputs=[], - user_actions=[UserAction(id="approve", title="Approve")], - ) - - human_a_config = {"id": "human_a", "data": human_data.model_dump()} - 
human_a = HumanInputNode( - id=human_a_config["id"], - config=human_a_config, - graph_init_params=graph_init_params, - graph_runtime_state=runtime_state, - form_repository=repo, - ) - - human_b_config = {"id": "human_b", "data": human_data.model_dump()} - human_b = DelayedHumanInputNode( - id=human_b_config["id"], - config=human_b_config, - graph_init_params=graph_init_params, - graph_runtime_state=runtime_state, - form_repository=repo, - delay_seconds=0.2, - ) - - llm_data = LLMNodeData( - title="LLM A", - model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}), - prompt_template=[ - LLMNodeChatModelMessage( - text="Prompt A", - role=PromptMessageRole.USER, - edition_type="basic", - ) - ], - context=ContextConfig(enabled=False, variable_selector=None), - vision=VisionConfig(enabled=False), - reasoning_format="tagged", - structured_output_enabled=False, - ) - llm_config = {"id": "llm_a", "data": llm_data.model_dump()} - llm_a = MockLLMNode( - id=llm_config["id"], - config=llm_config, - graph_init_params=graph_init_params, - graph_runtime_state=runtime_state, - mock_config=mock_config, - ) - - return ( - Graph.new() - .add_root(start_node) - .add_node(human_a, from_node_id="start") - .add_node(human_b, from_node_id="start") - .add_node(llm_a, from_node_id="human_a", source_handle="approve") - .build() - ) - - -def test_parallel_human_input_pause_preserves_node_finished() -> None: - runtime_state = _build_runtime_state() - - runtime_state.graph_execution.start() - runtime_state.register_paused_node("human_a") - runtime_state.register_paused_node("human_b") - - submitted = StaticForm( - form_id="form-a", - rendered="rendered", - is_submitted=True, - action_id="approve", - data={}, - status_value=HumanInputFormStatus.SUBMITTED, - ) - pending = StaticForm( - form_id="form-b", - rendered="rendered", - is_submitted=False, - action_id=None, - data=None, - status_value=HumanInputFormStatus.WAITING, - ) - repo = 
StaticRepo({"human_a": submitted, "human_b": pending}) - - mock_config = MockConfig() - mock_config.simulate_delays = True - mock_config.set_node_config( - "llm_a", - NodeMockConfig(node_id="llm_a", outputs={"text": "LLM A output"}, delay=0.5), - ) - - graph = _build_graph(runtime_state, repo, mock_config) - engine = GraphEngine( - workflow_id="workflow", - graph=graph, - graph_runtime_state=runtime_state, - command_channel=InMemoryChannel(), - config=GraphEngineConfig( - min_workers=2, - max_workers=2, - scale_up_threshold=1, - scale_down_idle_time=30.0, - ), - ) - - events = list(engine.run()) - - llm_started = any(isinstance(e, NodeRunStartedEvent) and e.node_id == "llm_a" for e in events) - llm_succeeded = any(isinstance(e, NodeRunSucceededEvent) and e.node_id == "llm_a" for e in events) - human_b_pause = any(isinstance(e, NodeRunPauseRequestedEvent) and e.node_id == "human_b" for e in events) - graph_paused = any(isinstance(e, GraphRunPausedEvent) for e in events) - graph_started = any(isinstance(e, GraphRunStartedEvent) for e in events) - - assert graph_started - assert graph_paused - assert human_b_pause - assert llm_started - assert llm_succeeded - - -def test_parallel_human_input_pause_preserves_node_finished_after_snapshot_resume() -> None: - base_state = _build_runtime_state() - base_state.graph_execution.start() - base_state.register_paused_node("human_a") - base_state.register_paused_node("human_b") - snapshot = base_state.dumps() - - resumed_state = GraphRuntimeState.from_snapshot(snapshot) - - submitted = StaticForm( - form_id="form-a", - rendered="rendered", - is_submitted=True, - action_id="approve", - data={}, - status_value=HumanInputFormStatus.SUBMITTED, - ) - pending = StaticForm( - form_id="form-b", - rendered="rendered", - is_submitted=False, - action_id=None, - data=None, - status_value=HumanInputFormStatus.WAITING, - ) - repo = StaticRepo({"human_a": submitted, "human_b": pending}) - - mock_config = MockConfig() - 
mock_config.simulate_delays = True - mock_config.set_node_config( - "llm_a", - NodeMockConfig(node_id="llm_a", outputs={"text": "LLM A output"}, delay=0.5), - ) - - graph = _build_graph(resumed_state, repo, mock_config) - engine = GraphEngine( - workflow_id="workflow", - graph=graph, - graph_runtime_state=resumed_state, - command_channel=InMemoryChannel(), - config=GraphEngineConfig( - min_workers=2, - max_workers=2, - scale_up_threshold=1, - scale_down_idle_time=30.0, - ), - ) - - events = list(engine.run()) - - start_event = next(e for e in events if isinstance(e, GraphRunStartedEvent)) - assert start_event.reason is WorkflowStartReason.RESUMPTION - - llm_started = any(isinstance(e, NodeRunStartedEvent) and e.node_id == "llm_a" for e in events) - llm_succeeded = any(isinstance(e, NodeRunSucceededEvent) and e.node_id == "llm_a" for e in events) - human_b_pause = any(isinstance(e, NodeRunPauseRequestedEvent) and e.node_id == "human_b" for e in events) - graph_paused = any(isinstance(e, GraphRunPausedEvent) for e in events) - - assert graph_paused - assert human_b_pause - assert llm_started - assert llm_succeeded diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_streaming_workflow.py b/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_streaming_workflow.py deleted file mode 100644 index b954a4faac..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_parallel_streaming_workflow.py +++ /dev/null @@ -1,286 +0,0 @@ -""" -Test for parallel streaming workflow behavior. 
- -This test validates that: -- LLM 1 always speaks English -- LLM 2 always speaks Chinese -- 2 LLMs run parallel, but LLM 2 will output before LLM 1 -- All chunks should be sent before Answer Node started -""" - -import time -from unittest.mock import MagicMock, patch -from uuid import uuid4 - -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom -from core.model_manager import ModelInstance -from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from dify_graph.graph import Graph -from dify_graph.graph_engine import GraphEngine, GraphEngineConfig -from dify_graph.graph_engine.command_channels import InMemoryChannel -from dify_graph.graph_events import ( - GraphRunSucceededEvent, - NodeRunStartedEvent, - NodeRunStreamChunkEvent, - NodeRunSucceededEvent, -) -from dify_graph.node_events import NodeRunResult, StreamCompletedEvent -from dify_graph.nodes.llm.node import LLMNode -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable -from tests.workflow_test_utils import build_test_graph_init_params - -from .test_table_runner import TableTestRunner - - -def create_llm_generator_with_delay(chunks: list[str], delay: float = 0.1): - """Create a generator that simulates LLM streaming output with delay""" - - def llm_generator(self): - for i, chunk in enumerate(chunks): - time.sleep(delay) # Simulate network delay - yield NodeRunStreamChunkEvent( - id=str(uuid4()), - node_id=self.id, - node_type=self.node_type, - selector=[self.id, "text"], - chunk=chunk, - is_final=i == len(chunks) - 1, - ) - - # Complete response - full_text = "".join(chunks) - yield StreamCompletedEvent( - node_run_result=NodeRunResult( - status=WorkflowNodeExecutionStatus.SUCCEEDED, - outputs={"text": full_text}, - ) - ) - - return llm_generator - - -def test_parallel_streaming_workflow(): - """ - Test parallel streaming 
workflow to verify: - 1. All chunks from LLM 2 are output before LLM 1 - 2. At least one chunk from LLM 2 is output before LLM 1 completes (Success) - 3. At least one chunk from LLM 1 is output before LLM 2 completes (EXPECTED TO FAIL) - 4. All chunks are output before End begins - 5. The final output content matches the order defined in the Answer - - Test setup: - - LLM 1 outputs English (slower) - - LLM 2 outputs Chinese (faster) - - Both run in parallel - - This test is expected to FAIL because chunks are currently buffered - until after node completion instead of streaming during execution. - """ - runner = TableTestRunner() - - # Load the workflow configuration - fixture_data = runner.workflow_runner.load_fixture("multilingual_parallel_llm_streaming_workflow") - workflow_config = fixture_data.get("workflow", {}) - graph_config = workflow_config.get("graph", {}) - - # Create graph initialization parameters - init_params = build_test_graph_init_params( - workflow_id="test_workflow", - graph_config=graph_config, - tenant_id="test_tenant", - app_id="test_app", - user_id="test_user", - user_from=UserFrom.ACCOUNT, - invoke_from=InvokeFrom.WEB_APP, - call_depth=0, - ) - - # Create variable pool with system variables - system_variables = SystemVariable( - user_id="test_user", - app_id="test_app", - workflow_id=init_params.workflow_id, - files=[], - query="Tell me about yourself", # User query - ) - variable_pool = VariablePool( - system_variables=system_variables, - user_inputs={}, - ) - - # Create graph runtime state - graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - - # Create node factory and graph - node_factory = DifyNodeFactory(graph_init_params=init_params, graph_runtime_state=graph_runtime_state) - with patch.object( - DifyNodeFactory, "_build_model_instance_for_llm_node", return_value=MagicMock(spec=ModelInstance), autospec=True - ): - graph = Graph.init( - graph_config=graph_config, - 
node_factory=node_factory, - root_node_id=get_default_root_node_id(graph_config), - ) - - # Create the graph engine - engine = GraphEngine( - workflow_id="test_workflow", - graph=graph, - graph_runtime_state=graph_runtime_state, - command_channel=InMemoryChannel(), - config=GraphEngineConfig(), - ) - - # Define LLM outputs - llm1_chunks = ["Hello", ", ", "I", " ", "am", " ", "an", " ", "AI", " ", "assistant", "."] # English (slower) - llm2_chunks = ["你好", ",", "我", "是", "AI", "助手", "。"] # Chinese (faster) - - # Create generators with different delays (LLM 2 is faster) - llm1_generator = create_llm_generator_with_delay(llm1_chunks, delay=0.05) # Slower - llm2_generator = create_llm_generator_with_delay(llm2_chunks, delay=0.01) # Faster - - # Track which LLM node is being called - llm_call_order = [] - generators = { - "1754339718571": llm1_generator, # LLM 1 node ID - "1754339725656": llm2_generator, # LLM 2 node ID - } - - def mock_llm_run(self): - llm_call_order.append(self.id) - generator = generators.get(self.id) - if generator: - yield from generator(self) - else: - raise Exception(f"Unexpected LLM node ID: {self.id}") - - # Execute with mocked LLMs - with patch.object(LLMNode, "_run", new=mock_llm_run): - events = list(engine.run()) - - # Check for successful completion - success_events = [e for e in events if isinstance(e, GraphRunSucceededEvent)] - assert len(success_events) > 0, "Workflow should complete successfully" - - # Get all streaming chunk events - stream_chunk_events = [e for e in events if isinstance(e, NodeRunStreamChunkEvent)] - - # Get Answer node start event - answer_start_events = [ - e for e in events if isinstance(e, NodeRunStartedEvent) and e.node_type == BuiltinNodeTypes.ANSWER - ] - assert len(answer_start_events) == 1, f"Expected 1 Answer node start event, got {len(answer_start_events)}" - answer_start_event = answer_start_events[0] - - # Find the index of Answer node start - answer_start_index = events.index(answer_start_event) - - # 
Collect chunk events by node - llm1_chunks_events = [e for e in stream_chunk_events if e.node_id == "1754339718571"] - llm2_chunks_events = [e for e in stream_chunk_events if e.node_id == "1754339725656"] - - # Verify both LLMs produced chunks - assert len(llm1_chunks_events) == len(llm1_chunks), ( - f"Expected {len(llm1_chunks)} chunks from LLM 1, got {len(llm1_chunks_events)}" - ) - assert len(llm2_chunks_events) == len(llm2_chunks), ( - f"Expected {len(llm2_chunks)} chunks from LLM 2, got {len(llm2_chunks_events)}" - ) - - # 1. Verify chunk ordering based on actual implementation - llm1_chunk_indices = [events.index(e) for e in llm1_chunks_events] - llm2_chunk_indices = [events.index(e) for e in llm2_chunks_events] - - # In the current implementation, chunks may be interleaved or in a specific order - # Update this based on actual behavior observed - if llm1_chunk_indices and llm2_chunk_indices: - # Check the actual ordering - if LLM 2 chunks come first (as seen in debug) - assert max(llm2_chunk_indices) < min(llm1_chunk_indices), ( - f"All LLM 2 chunks should be output before LLM 1 chunks. " - f"LLM 2 chunk indices: {llm2_chunk_indices}, LLM 1 chunk indices: {llm1_chunk_indices}" - ) - - # Get indices of all chunk events - chunk_indices = [events.index(e) for e in stream_chunk_events if e in llm1_chunks_events + llm2_chunks_events] - - # 4. Verify all chunks were sent before Answer node started - assert all(idx < answer_start_index for idx in chunk_indices), ( - "All LLM chunks should be sent before Answer node starts" - ) - - # The test has successfully verified: - # 1. Both LLMs run in parallel (they start at the same time) - # 2. LLM 2 (Chinese) outputs all its chunks before LLM 1 (English) due to faster processing - # 3. 
All LLM chunks are sent before the Answer node starts - - # Get LLM completion events - llm_completed_events = [ - (i, e) - for i, e in enumerate(events) - if isinstance(e, NodeRunSucceededEvent) and e.node_type == BuiltinNodeTypes.LLM - ] - - # Check LLM completion order - in the current implementation, LLMs run sequentially - # LLM 1 completes first, then LLM 2 runs and completes - assert len(llm_completed_events) == 2, f"Expected 2 LLM completion events, got {len(llm_completed_events)}" - llm2_complete_idx = next((i for i, e in llm_completed_events if e.node_id == "1754339725656"), None) - llm1_complete_idx = next((i for i, e in llm_completed_events if e.node_id == "1754339718571"), None) - assert llm2_complete_idx is not None, "LLM 2 completion event not found" - assert llm1_complete_idx is not None, "LLM 1 completion event not found" - # In the actual implementation, LLM 1 completes before LLM 2 (sequential execution) - assert llm1_complete_idx < llm2_complete_idx, ( - f"LLM 1 should complete before LLM 2 in sequential execution, but LLM 1 completed at {llm1_complete_idx} " - f"and LLM 2 completed at {llm2_complete_idx}" - ) - - # 2. In sequential execution, LLM 2 chunks appear AFTER LLM 1 completes - if llm2_chunk_indices: - # LLM 1 completes first, then LLM 2 starts streaming - assert min(llm2_chunk_indices) > llm1_complete_idx, ( - f"LLM 2 chunks should appear after LLM 1 completes in sequential execution. " - f"First LLM 2 chunk at index {min(llm2_chunk_indices)}, LLM 1 completed at index {llm1_complete_idx}" - ) - - # 3. 
In the current implementation, LLM 1 chunks appear after LLM 2 completes - # This is because chunks are buffered and output after both nodes complete - if llm1_chunk_indices and llm2_complete_idx: - # Check if LLM 1 chunks exist and where they appear relative to LLM 2 completion - # In current behavior, LLM 1 chunks typically appear after LLM 2 completes - pass # Skipping this check as the chunk ordering is implementation-dependent - - # CURRENT BEHAVIOR: Chunks are buffered and appear after node completion - # In the sequential execution, LLM 1 completes first without streaming, - # then LLM 2 streams its chunks - assert stream_chunk_events, "Expected streaming events, but got none" - - first_chunk_index = events.index(stream_chunk_events[0]) - llm_success_indices = [i for i, e in llm_completed_events] - - # Current implementation: LLM 1 completes first, then chunks start appearing - # This is the actual behavior we're testing - if llm_success_indices: - # At least one LLM (LLM 1) completes before any chunks appear - assert min(llm_success_indices) < first_chunk_index, ( - f"In current implementation, LLM 1 completes before chunks start streaming. " - f"First chunk at index {first_chunk_index}, LLM 1 completed at index {min(llm_success_indices)}" - ) - - # 5. Verify final output content matches the order defined in Answer node - # According to Answer node configuration: '{{#1754339725656.text#}}{{#1754339718571.text#}}' - # This means LLM 2 output should come first, then LLM 1 output - answer_complete_events = [ - e for e in events if isinstance(e, NodeRunSucceededEvent) and e.node_type == BuiltinNodeTypes.ANSWER - ] - assert len(answer_complete_events) == 1, f"Expected 1 Answer completion event, got {len(answer_complete_events)}" - - answer_outputs = answer_complete_events[0].node_run_result.outputs - expected_answer_text = "你好,我是AI助手。Hello, I am an AI assistant." 
- - if "answer" in answer_outputs: - actual_answer_text = answer_outputs["answer"] - assert actual_answer_text == expected_answer_text, ( - f"Answer content should match the order defined in Answer node. " - f"Expected: '{expected_answer_text}', Got: '{actual_answer_text}'" - ) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_pause_deferred_ready_nodes.py b/api/tests/unit_tests/core/workflow/graph_engine/test_pause_deferred_ready_nodes.py deleted file mode 100644 index 7328ce443f..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_pause_deferred_ready_nodes.py +++ /dev/null @@ -1,309 +0,0 @@ -import time -from collections.abc import Mapping -from dataclasses import dataclass -from datetime import datetime, timedelta -from typing import Any - -from dify_graph.entities.workflow_start_reason import WorkflowStartReason -from dify_graph.graph import Graph -from dify_graph.graph_engine.command_channels.in_memory_channel import InMemoryChannel -from dify_graph.graph_engine.config import GraphEngineConfig -from dify_graph.graph_engine.graph_engine import GraphEngine -from dify_graph.graph_events import ( - GraphRunPausedEvent, - GraphRunStartedEvent, - NodeRunStartedEvent, - NodeRunSucceededEvent, -) -from dify_graph.model_runtime.entities.llm_entities import LLMMode -from dify_graph.model_runtime.entities.message_entities import PromptMessageRole -from dify_graph.nodes.end.end_node import EndNode -from dify_graph.nodes.end.entities import EndNodeData -from dify_graph.nodes.human_input.entities import HumanInputNodeData, UserAction -from dify_graph.nodes.human_input.enums import HumanInputFormStatus -from dify_graph.nodes.human_input.human_input_node import HumanInputNode -from dify_graph.nodes.llm.entities import ( - ContextConfig, - LLMNodeChatModelMessage, - LLMNodeData, - ModelConfig, - VisionConfig, -) -from dify_graph.nodes.start.entities import StartNodeData -from dify_graph.nodes.start.start_node import StartNode -from 
dify_graph.repositories.human_input_form_repository import ( - FormCreateParams, - HumanInputFormEntity, - HumanInputFormRepository, -) -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable -from libs.datetime_utils import naive_utc_now -from tests.workflow_test_utils import build_test_graph_init_params - -from .test_mock_config import MockConfig, NodeMockConfig -from .test_mock_nodes import MockLLMNode - - -@dataclass -class StaticForm(HumanInputFormEntity): - form_id: str - rendered: str - is_submitted: bool - action_id: str | None = None - data: Mapping[str, Any] | None = None - status_value: HumanInputFormStatus = HumanInputFormStatus.WAITING - expiration: datetime = naive_utc_now() + timedelta(days=1) - - @property - def id(self) -> str: - return self.form_id - - @property - def web_app_token(self) -> str | None: - return "token" - - @property - def recipients(self) -> list: - return [] - - @property - def rendered_content(self) -> str: - return self.rendered - - @property - def selected_action_id(self) -> str | None: - return self.action_id - - @property - def submitted_data(self) -> Mapping[str, Any] | None: - return self.data - - @property - def submitted(self) -> bool: - return self.is_submitted - - @property - def status(self) -> HumanInputFormStatus: - return self.status_value - - @property - def expiration_time(self) -> datetime: - return self.expiration - - -class StaticRepo(HumanInputFormRepository): - def __init__(self, form: HumanInputFormEntity) -> None: - self._form = form - - def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None: - if node_id != "human_pause": - return None - return self._form - - def create_form(self, params: FormCreateParams) -> HumanInputFormEntity: - raise AssertionError("create_form should not be called in this test") - - -def _build_runtime_state() -> GraphRuntimeState: - variable_pool = VariablePool( - 
system_variables=SystemVariable( - user_id="user", - app_id="app", - workflow_id="workflow", - workflow_execution_id="exec-1", - ), - user_inputs={}, - conversation_variables=[], - ) - return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - - -def _build_graph(runtime_state: GraphRuntimeState, repo: HumanInputFormRepository, mock_config: MockConfig) -> Graph: - graph_config: dict[str, object] = {"nodes": [], "edges": []} - graph_init_params = build_test_graph_init_params( - workflow_id="workflow", - graph_config=graph_config, - tenant_id="tenant", - app_id="app", - user_id="user", - user_from="account", - invoke_from="debugger", - call_depth=0, - ) - - start_config = {"id": "start", "data": StartNodeData(title="Start", variables=[]).model_dump()} - start_node = StartNode( - id=start_config["id"], - config=start_config, - graph_init_params=graph_init_params, - graph_runtime_state=runtime_state, - ) - - llm_a_data = LLMNodeData( - title="LLM A", - model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}), - prompt_template=[ - LLMNodeChatModelMessage( - text="Prompt A", - role=PromptMessageRole.USER, - edition_type="basic", - ) - ], - context=ContextConfig(enabled=False, variable_selector=None), - vision=VisionConfig(enabled=False), - reasoning_format="tagged", - structured_output_enabled=False, - ) - llm_a_config = {"id": "llm_a", "data": llm_a_data.model_dump()} - llm_a = MockLLMNode( - id=llm_a_config["id"], - config=llm_a_config, - graph_init_params=graph_init_params, - graph_runtime_state=runtime_state, - mock_config=mock_config, - ) - - llm_b_data = LLMNodeData( - title="LLM B", - model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}), - prompt_template=[ - LLMNodeChatModelMessage( - text="Prompt B", - role=PromptMessageRole.USER, - edition_type="basic", - ) - ], - context=ContextConfig(enabled=False, variable_selector=None), - 
vision=VisionConfig(enabled=False), - reasoning_format="tagged", - structured_output_enabled=False, - ) - llm_b_config = {"id": "llm_b", "data": llm_b_data.model_dump()} - llm_b = MockLLMNode( - id=llm_b_config["id"], - config=llm_b_config, - graph_init_params=graph_init_params, - graph_runtime_state=runtime_state, - mock_config=mock_config, - ) - - human_data = HumanInputNodeData( - title="Human Input", - form_content="Pause here", - inputs=[], - user_actions=[UserAction(id="approve", title="Approve")], - ) - human_config = {"id": "human_pause", "data": human_data.model_dump()} - human_node = HumanInputNode( - id=human_config["id"], - config=human_config, - graph_init_params=graph_init_params, - graph_runtime_state=runtime_state, - form_repository=repo, - ) - - end_human_data = EndNodeData(title="End Human", outputs=[], desc=None) - end_human_config = {"id": "end_human", "data": end_human_data.model_dump()} - end_human = EndNode( - id=end_human_config["id"], - config=end_human_config, - graph_init_params=graph_init_params, - graph_runtime_state=runtime_state, - ) - - return ( - Graph.new() - .add_root(start_node) - .add_node(llm_a, from_node_id="start") - .add_node(human_node, from_node_id="start") - .add_node(llm_b, from_node_id="llm_a") - .add_node(end_human, from_node_id="human_pause", source_handle="approve") - .build() - ) - - -def _get_node_started_event(events: list[object], node_id: str) -> NodeRunStartedEvent | None: - for event in events: - if isinstance(event, NodeRunStartedEvent) and event.node_id == node_id: - return event - return None - - -def test_pause_defers_ready_nodes_until_resume() -> None: - runtime_state = _build_runtime_state() - - paused_form = StaticForm( - form_id="form-pause", - rendered="rendered", - is_submitted=False, - status_value=HumanInputFormStatus.WAITING, - ) - pause_repo = StaticRepo(paused_form) - - mock_config = MockConfig() - mock_config.simulate_delays = True - mock_config.set_node_config( - "llm_a", - 
NodeMockConfig(node_id="llm_a", outputs={"text": "LLM A output"}, delay=0.5), - ) - mock_config.set_node_config( - "llm_b", - NodeMockConfig(node_id="llm_b", outputs={"text": "LLM B output"}, delay=0.0), - ) - - graph = _build_graph(runtime_state, pause_repo, mock_config) - engine = GraphEngine( - workflow_id="workflow", - graph=graph, - graph_runtime_state=runtime_state, - command_channel=InMemoryChannel(), - config=GraphEngineConfig( - min_workers=2, - max_workers=2, - scale_up_threshold=1, - scale_down_idle_time=30.0, - ), - ) - - paused_events = list(engine.run()) - - assert any(isinstance(e, GraphRunPausedEvent) for e in paused_events) - assert any(isinstance(e, NodeRunSucceededEvent) and e.node_id == "llm_a" for e in paused_events) - assert _get_node_started_event(paused_events, "llm_b") is None - - snapshot = runtime_state.dumps() - resumed_state = GraphRuntimeState.from_snapshot(snapshot) - - submitted_form = StaticForm( - form_id="form-pause", - rendered="rendered", - is_submitted=True, - action_id="approve", - data={}, - status_value=HumanInputFormStatus.SUBMITTED, - ) - resume_repo = StaticRepo(submitted_form) - - resumed_graph = _build_graph(resumed_state, resume_repo, mock_config) - resumed_engine = GraphEngine( - workflow_id="workflow", - graph=resumed_graph, - graph_runtime_state=resumed_state, - command_channel=InMemoryChannel(), - config=GraphEngineConfig( - min_workers=2, - max_workers=2, - scale_up_threshold=1, - scale_down_idle_time=30.0, - ), - ) - - resumed_events = list(resumed_engine.run()) - - start_event = next(e for e in resumed_events if isinstance(e, GraphRunStartedEvent)) - assert start_event.reason is WorkflowStartReason.RESUMPTION - - llm_b_started = _get_node_started_event(resumed_events, "llm_b") - assert llm_b_started is not None - assert any(isinstance(e, NodeRunSucceededEvent) and e.node_id == "llm_b" for e in resumed_events) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_pause_resume_state.py 
b/api/tests/unit_tests/core/workflow/graph_engine/test_pause_resume_state.py deleted file mode 100644 index 15a7de3c52..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_pause_resume_state.py +++ /dev/null @@ -1,217 +0,0 @@ -import datetime -import time -from typing import Any -from unittest.mock import MagicMock - -from dify_graph.entities.workflow_start_reason import WorkflowStartReason -from dify_graph.graph import Graph -from dify_graph.graph_engine.command_channels.in_memory_channel import InMemoryChannel -from dify_graph.graph_engine.graph_engine import GraphEngine -from dify_graph.graph_events import ( - GraphEngineEvent, - GraphRunPausedEvent, - GraphRunSucceededEvent, - NodeRunStartedEvent, - NodeRunSucceededEvent, -) -from dify_graph.graph_events.graph import GraphRunStartedEvent -from dify_graph.nodes.base.entities import OutputVariableEntity -from dify_graph.nodes.end.end_node import EndNode -from dify_graph.nodes.end.entities import EndNodeData -from dify_graph.nodes.human_input.entities import HumanInputNodeData, UserAction -from dify_graph.nodes.human_input.human_input_node import HumanInputNode -from dify_graph.nodes.start.entities import StartNodeData -from dify_graph.nodes.start.start_node import StartNode -from dify_graph.repositories.human_input_form_repository import ( - HumanInputFormEntity, - HumanInputFormRepository, -) -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable -from libs.datetime_utils import naive_utc_now -from tests.workflow_test_utils import build_test_graph_init_params - - -def _build_runtime_state() -> GraphRuntimeState: - variable_pool = VariablePool( - system_variables=SystemVariable( - user_id="user", - app_id="app", - workflow_id="workflow", - workflow_execution_id="test-execution-id", - ), - user_inputs={}, - conversation_variables=[], - ) - return GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - - -def 
_mock_form_repository_with_submission(action_id: str) -> HumanInputFormRepository: - repo = MagicMock(spec=HumanInputFormRepository) - form_entity = MagicMock(spec=HumanInputFormEntity) - form_entity.id = "test-form-id" - form_entity.web_app_token = "test-form-token" - form_entity.recipients = [] - form_entity.rendered_content = "rendered" - form_entity.submitted = True - form_entity.selected_action_id = action_id - form_entity.submitted_data = {} - form_entity.expiration_time = naive_utc_now() + datetime.timedelta(days=1) - repo.get_form.return_value = form_entity - return repo - - -def _mock_form_repository_without_submission() -> HumanInputFormRepository: - repo = MagicMock(spec=HumanInputFormRepository) - form_entity = MagicMock(spec=HumanInputFormEntity) - form_entity.id = "test-form-id" - form_entity.web_app_token = "test-form-token" - form_entity.recipients = [] - form_entity.rendered_content = "rendered" - form_entity.submitted = False - repo.create_form.return_value = form_entity - repo.get_form.return_value = None - return repo - - -def _build_human_input_graph( - runtime_state: GraphRuntimeState, - form_repository: HumanInputFormRepository, -) -> Graph: - graph_config: dict[str, object] = {"nodes": [], "edges": []} - params = build_test_graph_init_params( - workflow_id="workflow", - graph_config=graph_config, - tenant_id="tenant", - app_id="app", - user_id="user", - user_from="account", - invoke_from="service-api", - call_depth=0, - ) - - start_data = StartNodeData(title="start", variables=[]) - start_node = StartNode( - id="start", - config={"id": "start", "data": start_data.model_dump()}, - graph_init_params=params, - graph_runtime_state=runtime_state, - ) - - human_data = HumanInputNodeData( - title="human", - form_content="Awaiting human input", - inputs=[], - user_actions=[ - UserAction(id="continue", title="Continue"), - ], - ) - human_node = HumanInputNode( - id="human", - config={"id": "human", "data": human_data.model_dump()}, - 
graph_init_params=params, - graph_runtime_state=runtime_state, - form_repository=form_repository, - ) - - end_data = EndNodeData( - title="end", - outputs=[ - OutputVariableEntity(variable="result", value_selector=["human", "action_id"]), - ], - desc=None, - ) - end_node = EndNode( - id="end", - config={"id": "end", "data": end_data.model_dump()}, - graph_init_params=params, - graph_runtime_state=runtime_state, - ) - - return ( - Graph.new() - .add_root(start_node) - .add_node(human_node) - .add_node(end_node, from_node_id="human", source_handle="continue") - .build() - ) - - -def _run_graph(graph: Graph, runtime_state: GraphRuntimeState) -> list[GraphEngineEvent]: - engine = GraphEngine( - workflow_id="workflow", - graph=graph, - graph_runtime_state=runtime_state, - command_channel=InMemoryChannel(), - ) - return list(engine.run()) - - -def _node_successes(events: list[GraphEngineEvent]) -> list[str]: - return [event.node_id for event in events if isinstance(event, NodeRunSucceededEvent)] - - -def _node_start_event(events: list[GraphEngineEvent], node_id: str) -> NodeRunStartedEvent | None: - for event in events: - if isinstance(event, NodeRunStartedEvent) and event.node_id == node_id: - return event - return None - - -def _segment_value(variable_pool: VariablePool, selector: tuple[str, str]) -> Any: - segment = variable_pool.get(selector) - assert segment is not None - return getattr(segment, "value", segment) - - -def test_engine_resume_restores_state_and_completion(): - # Baseline run without pausing - baseline_state = _build_runtime_state() - baseline_repo = _mock_form_repository_with_submission(action_id="continue") - baseline_graph = _build_human_input_graph(baseline_state, baseline_repo) - baseline_events = _run_graph(baseline_graph, baseline_state) - assert baseline_events - first_paused_event = baseline_events[0] - assert isinstance(first_paused_event, GraphRunStartedEvent) - assert first_paused_event.reason is WorkflowStartReason.INITIAL - assert 
isinstance(baseline_events[-1], GraphRunSucceededEvent) - baseline_success_nodes = _node_successes(baseline_events) - - # Run with pause - paused_state = _build_runtime_state() - pause_repo = _mock_form_repository_without_submission() - paused_graph = _build_human_input_graph(paused_state, pause_repo) - paused_events = _run_graph(paused_graph, paused_state) - assert paused_events - first_paused_event = paused_events[0] - assert isinstance(first_paused_event, GraphRunStartedEvent) - assert first_paused_event.reason is WorkflowStartReason.INITIAL - assert isinstance(paused_events[-1], GraphRunPausedEvent) - snapshot = paused_state.dumps() - - # Resume from snapshot - resumed_state = GraphRuntimeState.from_snapshot(snapshot) - resume_repo = _mock_form_repository_with_submission(action_id="continue") - resumed_graph = _build_human_input_graph(resumed_state, resume_repo) - resumed_events = _run_graph(resumed_graph, resumed_state) - assert resumed_events - first_resumed_event = resumed_events[0] - assert isinstance(first_resumed_event, GraphRunStartedEvent) - assert first_resumed_event.reason is WorkflowStartReason.RESUMPTION - assert isinstance(resumed_events[-1], GraphRunSucceededEvent) - - combined_success_nodes = _node_successes(paused_events) + _node_successes(resumed_events) - assert combined_success_nodes == baseline_success_nodes - - paused_human_started = _node_start_event(paused_events, "human") - resumed_human_started = _node_start_event(resumed_events, "human") - assert paused_human_started is not None - assert resumed_human_started is not None - assert paused_human_started.id == resumed_human_started.id - - assert baseline_state.outputs == resumed_state.outputs - assert _segment_value(baseline_state.variable_pool, ("human", "__action_id")) == _segment_value( - resumed_state.variable_pool, ("human", "__action_id") - ) - assert baseline_state.graph_execution.completed - assert resumed_state.graph_execution.completed diff --git 
a/api/tests/unit_tests/core/workflow/graph_engine/test_redis_stop_integration.py b/api/tests/unit_tests/core/workflow/graph_engine/test_redis_stop_integration.py deleted file mode 100644 index 9c84f42db6..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_redis_stop_integration.py +++ /dev/null @@ -1,268 +0,0 @@ -""" -Unit tests for Redis-based stop functionality in GraphEngine. - -Tests the integration of Redis command channel for stopping workflows -without user permission checks. -""" - -import json -from unittest.mock import MagicMock, Mock, patch - -import pytest -import redis - -from core.app.apps.base_app_queue_manager import AppQueueManager -from dify_graph.graph_engine.command_channels.redis_channel import RedisChannel -from dify_graph.graph_engine.entities.commands import AbortCommand, CommandType, PauseCommand -from dify_graph.graph_engine.manager import GraphEngineManager - - -class TestRedisStopIntegration: - """Test suite for Redis-based workflow stop functionality.""" - - def test_graph_engine_manager_sends_abort_command(self): - """Test that GraphEngineManager correctly sends abort command through Redis.""" - # Setup - task_id = "test-task-123" - expected_channel_key = f"workflow:{task_id}:commands" - - # Mock redis client - mock_redis = MagicMock() - mock_pipeline = MagicMock() - mock_redis.pipeline.return_value.__enter__ = Mock(return_value=mock_pipeline) - mock_redis.pipeline.return_value.__exit__ = Mock(return_value=None) - - manager = GraphEngineManager(mock_redis) - - # Execute - manager.send_stop_command(task_id, reason="Test stop") - - # Verify - mock_redis.pipeline.assert_called_once() - - # Check that rpush was called with correct arguments - calls = mock_pipeline.rpush.call_args_list - assert len(calls) == 1 - - # Verify the channel key - assert calls[0][0][0] == expected_channel_key - - # Verify the command data - command_json = calls[0][0][1] - command_data = json.loads(command_json) - assert 
command_data["command_type"] == CommandType.ABORT - assert command_data["reason"] == "Test stop" - - def test_graph_engine_manager_sends_pause_command(self): - """Test that GraphEngineManager correctly sends pause command through Redis.""" - task_id = "test-task-pause-123" - expected_channel_key = f"workflow:{task_id}:commands" - - mock_redis = MagicMock() - mock_pipeline = MagicMock() - mock_redis.pipeline.return_value.__enter__ = Mock(return_value=mock_pipeline) - mock_redis.pipeline.return_value.__exit__ = Mock(return_value=None) - - manager = GraphEngineManager(mock_redis) - manager.send_pause_command(task_id, reason="Awaiting resources") - - mock_redis.pipeline.assert_called_once() - calls = mock_pipeline.rpush.call_args_list - assert len(calls) == 1 - assert calls[0][0][0] == expected_channel_key - - command_json = calls[0][0][1] - command_data = json.loads(command_json) - assert command_data["command_type"] == CommandType.PAUSE.value - assert command_data["reason"] == "Awaiting resources" - - def test_graph_engine_manager_handles_redis_failure_gracefully(self): - """Test that GraphEngineManager handles Redis failures without raising exceptions.""" - task_id = "test-task-456" - - # Mock redis client to raise exception - mock_redis = MagicMock() - mock_redis.pipeline.side_effect = redis.ConnectionError("Redis connection failed") - manager = GraphEngineManager(mock_redis) - - # Should not raise exception - try: - manager.send_stop_command(task_id) - except Exception as e: - pytest.fail(f"GraphEngineManager.send_stop_command raised {e} unexpectedly") - - def test_app_queue_manager_no_user_check(self): - """Test that AppQueueManager.set_stop_flag_no_user_check works without user validation.""" - task_id = "test-task-789" - expected_cache_key = f"generate_task_stopped:{task_id}" - - # Mock redis client - mock_redis = MagicMock() - - with patch("core.app.apps.base_app_queue_manager.redis_client", mock_redis): - # Execute - 
AppQueueManager.set_stop_flag_no_user_check(task_id) - - # Verify - mock_redis.setex.assert_called_once_with(expected_cache_key, 600, 1) - - def test_app_queue_manager_no_user_check_with_empty_task_id(self): - """Test that AppQueueManager.set_stop_flag_no_user_check handles empty task_id.""" - # Mock redis client - mock_redis = MagicMock() - - with patch("core.app.apps.base_app_queue_manager.redis_client", mock_redis): - # Execute with empty task_id - AppQueueManager.set_stop_flag_no_user_check("") - - # Verify redis was not called - mock_redis.setex.assert_not_called() - - def test_redis_channel_send_abort_command(self): - """Test RedisChannel correctly serializes and sends AbortCommand.""" - # Setup - mock_redis = MagicMock() - mock_pipeline = MagicMock() - mock_redis.pipeline.return_value.__enter__ = Mock(return_value=mock_pipeline) - mock_redis.pipeline.return_value.__exit__ = Mock(return_value=None) - - channel_key = "workflow:test:commands" - channel = RedisChannel(mock_redis, channel_key) - - # Create commands - abort_command = AbortCommand(reason="User requested stop") - pause_command = PauseCommand(reason="User requested pause") - - # Execute - channel.send_command(abort_command) - channel.send_command(pause_command) - - # Verify - mock_redis.pipeline.assert_called() - - # Check rpush was called - calls = mock_pipeline.rpush.call_args_list - assert len(calls) == 2 - assert calls[0][0][0] == channel_key - assert calls[1][0][0] == channel_key - - # Verify serialized commands - abort_command_json = calls[0][0][1] - abort_command_data = json.loads(abort_command_json) - assert abort_command_data["command_type"] == CommandType.ABORT.value - assert abort_command_data["reason"] == "User requested stop" - - pause_command_json = calls[1][0][1] - pause_command_data = json.loads(pause_command_json) - assert pause_command_data["command_type"] == CommandType.PAUSE.value - assert pause_command_data["reason"] == "User requested pause" - - # Check expire was set for each - 
assert mock_pipeline.expire.call_count == 2 - mock_pipeline.expire.assert_any_call(channel_key, 3600) - - def test_redis_channel_fetch_commands(self): - """Test RedisChannel correctly fetches and deserializes commands.""" - # Setup - mock_redis = MagicMock() - pending_pipe = MagicMock() - fetch_pipe = MagicMock() - pending_context = MagicMock() - fetch_context = MagicMock() - pending_context.__enter__.return_value = pending_pipe - pending_context.__exit__.return_value = None - fetch_context.__enter__.return_value = fetch_pipe - fetch_context.__exit__.return_value = None - mock_redis.pipeline.side_effect = [pending_context, fetch_context] - - # Mock command data - abort_command_json = json.dumps( - {"command_type": CommandType.ABORT.value, "reason": "Test abort", "payload": None} - ) - pause_command_json = json.dumps( - {"command_type": CommandType.PAUSE.value, "reason": "Pause requested", "payload": None} - ) - - # Mock pipeline execute to return commands - pending_pipe.execute.return_value = [b"1", 1] - fetch_pipe.execute.return_value = [ - [abort_command_json.encode(), pause_command_json.encode()], # lrange result - True, # delete result - ] - - channel_key = "workflow:test:commands" - channel = RedisChannel(mock_redis, channel_key) - - # Execute - commands = channel.fetch_commands() - - # Verify - assert len(commands) == 2 - assert isinstance(commands[0], AbortCommand) - assert commands[0].command_type == CommandType.ABORT - assert commands[0].reason == "Test abort" - assert isinstance(commands[1], PauseCommand) - assert commands[1].command_type == CommandType.PAUSE - assert commands[1].reason == "Pause requested" - - # Verify Redis operations - pending_pipe.get.assert_called_once_with(f"{channel_key}:pending") - pending_pipe.delete.assert_called_once_with(f"{channel_key}:pending") - fetch_pipe.lrange.assert_called_once_with(channel_key, 0, -1) - fetch_pipe.delete.assert_called_once_with(channel_key) - assert mock_redis.pipeline.call_count == 2 - - def 
test_redis_channel_fetch_commands_handles_invalid_json(self): - """Test RedisChannel gracefully handles invalid JSON in commands.""" - # Setup - mock_redis = MagicMock() - pending_pipe = MagicMock() - fetch_pipe = MagicMock() - pending_context = MagicMock() - fetch_context = MagicMock() - pending_context.__enter__.return_value = pending_pipe - pending_context.__exit__.return_value = None - fetch_context.__enter__.return_value = fetch_pipe - fetch_context.__exit__.return_value = None - mock_redis.pipeline.side_effect = [pending_context, fetch_context] - - # Mock invalid command data - pending_pipe.execute.return_value = [b"1", 1] - fetch_pipe.execute.return_value = [ - [b"invalid json", b'{"command_type": "invalid_type"}'], # lrange result - True, # delete result - ] - - channel_key = "workflow:test:commands" - channel = RedisChannel(mock_redis, channel_key) - - # Execute - commands = channel.fetch_commands() - - # Should return empty list due to invalid commands - assert len(commands) == 0 - - def test_dual_stop_mechanism_compatibility(self): - """Test that both stop mechanisms can work together.""" - task_id = "test-task-dual" - - # Mock redis client - mock_redis = MagicMock() - mock_pipeline = MagicMock() - mock_redis.pipeline.return_value.__enter__ = Mock(return_value=mock_pipeline) - mock_redis.pipeline.return_value.__exit__ = Mock(return_value=None) - - with patch("core.app.apps.base_app_queue_manager.redis_client", mock_redis): - # Execute both stop mechanisms - AppQueueManager.set_stop_flag_no_user_check(task_id) - GraphEngineManager(mock_redis).send_stop_command(task_id) - - # Verify legacy stop flag was set - expected_stop_flag_key = f"generate_task_stopped:{task_id}" - mock_redis.setex.assert_called_once_with(expected_stop_flag_key, 600, 1) - - # Verify command was sent through Redis channel - mock_redis.pipeline.assert_called() - calls = mock_pipeline.rpush.call_args_list - assert len(calls) == 1 - assert calls[0][0][0] == f"workflow:{task_id}:commands" 
diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_response_session.py b/api/tests/unit_tests/core/workflow/graph_engine/test_response_session.py deleted file mode 100644 index cd9d56f683..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_response_session.py +++ /dev/null @@ -1,55 +0,0 @@ -"""Unit tests for response session creation.""" - -from __future__ import annotations - -import pytest - -from dify_graph.enums import BuiltinNodeTypes, NodeExecutionType, NodeState, NodeType -from dify_graph.graph_engine.response_coordinator.session import ResponseSession -from dify_graph.nodes.base.template import Template, TextSegment - - -class DummyResponseNode: - """Minimal response-capable node for session tests.""" - - def __init__(self, *, node_id: str, node_type: NodeType, template: Template) -> None: - self.id = node_id - self.node_type = node_type - self.execution_type = NodeExecutionType.RESPONSE - self.state = NodeState.UNKNOWN - self._template = template - - def get_streaming_template(self) -> Template: - return self._template - - -class DummyNodeWithoutStreamingTemplate: - """Minimal node that violates the response-session contract.""" - - def __init__(self, *, node_id: str, node_type: NodeType) -> None: - self.id = node_id - self.node_type = node_type - self.execution_type = NodeExecutionType.RESPONSE - self.state = NodeState.UNKNOWN - - -def test_response_session_from_node_accepts_nodes_outside_previous_allowlist() -> None: - """Session creation depends on the streaming-template contract rather than node type.""" - node = DummyResponseNode( - node_id="llm-node", - node_type=BuiltinNodeTypes.LLM, - template=Template(segments=[TextSegment(text="hello")]), - ) - - session = ResponseSession.from_node(node) - - assert session.node_id == "llm-node" - assert session.template.segments == [TextSegment(text="hello")] - - -def test_response_session_from_node_requires_streaming_template_method() -> None: - """Allowed node types still need 
to implement the streaming-template contract.""" - node = DummyNodeWithoutStreamingTemplate(node_id="answer-node", node_type=BuiltinNodeTypes.ANSWER) - - with pytest.raises(TypeError, match="get_streaming_template"): - ResponseSession.from_node(node) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_streaming_conversation_variables.py b/api/tests/unit_tests/core/workflow/graph_engine/test_streaming_conversation_variables.py deleted file mode 100644 index 4f1741d4fb..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_streaming_conversation_variables.py +++ /dev/null @@ -1,77 +0,0 @@ -from dify_graph.graph_events import ( - GraphRunStartedEvent, - GraphRunSucceededEvent, - NodeRunStartedEvent, - NodeRunStreamChunkEvent, - NodeRunSucceededEvent, -) - -from .test_mock_config import MockConfigBuilder -from .test_table_runner import TableTestRunner, WorkflowTestCase - - -def test_streaming_conversation_variables(): - fixture_name = "test_streaming_conversation_variables" - - # The test expects the workflow to output the input query - # Since the workflow assigns sys.query to conversation variable "str" and then answers with it - input_query = "Hello, this is my test query" - - mock_config = MockConfigBuilder().build() - - case = WorkflowTestCase( - fixture_path=fixture_name, - use_auto_mock=False, # Don't use auto mock since we want to test actual variable assignment - mock_config=mock_config, - query=input_query, # Pass query as the sys.query value - inputs={}, # No additional inputs needed - expected_outputs={"answer": input_query}, # Expecting the input query to be output - expected_event_sequence=[ - GraphRunStartedEvent, - # START node - NodeRunStartedEvent, - NodeRunSucceededEvent, - # Variable Assigner node - NodeRunStartedEvent, - NodeRunStreamChunkEvent, - NodeRunSucceededEvent, - # ANSWER node - NodeRunStartedEvent, - NodeRunSucceededEvent, - GraphRunSucceededEvent, - ], - ) - - runner = TableTestRunner() - result = 
runner.run_test_case(case) - assert result.success, f"Test failed: {result.error}" - - -def test_streaming_conversation_variables_v1_overwrite_waits_for_assignment(): - fixture_name = "test_streaming_conversation_variables_v1_overwrite" - input_query = "overwrite-value" - - case = WorkflowTestCase( - fixture_path=fixture_name, - use_auto_mock=False, - mock_config=MockConfigBuilder().build(), - query=input_query, - inputs={}, - expected_outputs={"answer": f"Current Value Of `conv_var` is:{input_query}"}, - ) - - runner = TableTestRunner() - result = runner.run_test_case(case) - assert result.success, f"Test failed: {result.error}" - - events = result.events - conv_var_chunk_events = [ - event - for event in events - if isinstance(event, NodeRunStreamChunkEvent) and tuple(event.selector) == ("conversation", "conv_var") - ] - - assert conv_var_chunk_events, "Expected conversation variable chunk events to be emitted" - assert all(event.chunk == input_query for event in conv_var_chunk_events), ( - "Expected streamed conversation variable value to match the input query" - ) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py b/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py index ab8fb346b8..b11f957677 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_table_runner.py @@ -12,29 +12,24 @@ This module provides a robust table-driven testing framework with support for: import logging import time -from collections.abc import Callable, Mapping, Sequence +from collections.abc import Callable, Sequence from concurrent.futures import ThreadPoolExecutor, as_completed from dataclasses import dataclass, field from functools import lru_cache from pathlib import Path -from typing import Any, cast +from typing import Any -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom -from core.tools.utils.yaml_utils import _load_yaml_file 
-from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY, GraphInitParams -from dify_graph.graph import Graph -from dify_graph.graph_engine import GraphEngine, GraphEngineConfig -from dify_graph.graph_engine.command_channels import InMemoryChannel -from dify_graph.graph_engine.layers.base import GraphEngineLayer -from dify_graph.graph_events import ( +from graphon.entities import GraphInitParams +from graphon.graph import Graph +from graphon.graph_engine import GraphEngine, GraphEngineConfig +from graphon.graph_engine.command_channels import InMemoryChannel +from graphon.graph_events import ( GraphEngineEvent, GraphRunStartedEvent, GraphRunSucceededEvent, ) -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable -from dify_graph.variables import ( +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.variables import ( ArrayNumberVariable, ArrayObjectVariable, ArrayStringVariable, @@ -44,6 +39,12 @@ from dify_graph.variables import ( StringVariable, ) +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom +from core.tools.utils.yaml_utils import _load_yaml_file +from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id +from core.workflow.system_variables import build_bootstrap_variables, build_system_variables +from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool + from .test_mock_config import MockConfig from .test_mock_factory import MockNodeFactory @@ -60,20 +61,28 @@ class _TableTestChildEngineBuilder: *, workflow_id: str, graph_init_params: GraphInitParams, - graph_runtime_state: GraphRuntimeState, - graph_config: Mapping[str, Any], + parent_graph_runtime_state: GraphRuntimeState, root_node_id: str, - layers: Sequence[object] = (), + variable_pool: VariablePool | None = None, 
) -> GraphEngine: + child_graph_runtime_state = GraphRuntimeState( + variable_pool=variable_pool if variable_pool is not None else parent_graph_runtime_state.variable_pool, + start_at=time.perf_counter(), + execution_context=parent_graph_runtime_state.execution_context, + ) if self._use_mock_factory: node_factory = MockNodeFactory( graph_init_params=graph_init_params, - graph_runtime_state=graph_runtime_state, + graph_runtime_state=child_graph_runtime_state, mock_config=self._mock_config, ) else: - node_factory = DifyNodeFactory(graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state) + node_factory = DifyNodeFactory( + graph_init_params=graph_init_params, + graph_runtime_state=child_graph_runtime_state, + ) + graph_config = graph_init_params.graph_config child_graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id=root_node_id) if not child_graph: raise ValueError("child graph not found") @@ -81,13 +90,11 @@ class _TableTestChildEngineBuilder: child_engine = GraphEngine( workflow_id=workflow_id, graph=child_graph, - graph_runtime_state=graph_runtime_state, + graph_runtime_state=child_graph_runtime_state, command_channel=InMemoryChannel(), config=GraphEngineConfig(), child_engine_builder=self, ) - for layer in layers: - child_engine.layer(cast(GraphEngineLayer, layer)) return child_engine @@ -206,14 +213,15 @@ class WorkflowRunner: call_depth=0, ) - system_variables = SystemVariable( + system_variables = build_system_variables( user_id="test_user", app_id="test_app", workflow_id=graph_init_params.workflow_id, files=[], query=query, ) - user_inputs = inputs if inputs is not None else {} + root_node_inputs = dict(inputs or {}) + root_node_inputs.setdefault("query", query) # Extract conversation variables from workflow config conversation_variables = [] @@ -242,11 +250,16 @@ class WorkflowRunner: ) conversation_variables.append(var) - variable_pool = VariablePool( - system_variables=system_variables, - 
user_inputs=user_inputs, - conversation_variables=conversation_variables, + root_node_id = get_default_root_node_id(graph_config) + variable_pool = VariablePool() + add_variables_to_pool( + variable_pool, + build_bootstrap_variables( + system_variables=system_variables, + conversation_variables=conversation_variables, + ), ) + add_node_inputs_to_pool(variable_pool, node_id=root_node_id, inputs=root_node_inputs) graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) @@ -260,7 +273,7 @@ class WorkflowRunner: graph = Graph.init( graph_config=graph_config, node_factory=node_factory, - root_node_id=get_default_root_node_id(graph_config), + root_node_id=root_node_id, ) return graph, graph_runtime_state diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_tool_in_chatflow.py b/api/tests/unit_tests/core/workflow/graph_engine/test_tool_in_chatflow.py index 7f26bc11a7..12aec6edf2 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_tool_in_chatflow.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_tool_in_chatflow.py @@ -1,6 +1,6 @@ -from dify_graph.graph_engine import GraphEngine, GraphEngineConfig -from dify_graph.graph_engine.command_channels import InMemoryChannel -from dify_graph.graph_events import ( +from graphon.graph_engine import GraphEngine, GraphEngineConfig +from graphon.graph_engine.command_channels import InMemoryChannel +from graphon.graph_events import ( GraphRunSucceededEvent, NodeRunStreamChunkEvent, ) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_update_conversation_variable_iteration.py b/api/tests/unit_tests/core/workflow/graph_engine/test_update_conversation_variable_iteration.py deleted file mode 100644 index a7309f64de..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_update_conversation_variable_iteration.py +++ /dev/null @@ -1,41 +0,0 @@ -"""Validate conversation variable updates inside an iteration workflow. 
- -This test uses the ``update-conversation-variable-in-iteration`` fixture, which -routes ``sys.query`` into the conversation variable ``answer`` from within an -iteration container. The workflow should surface that updated conversation -variable in the final answer output. - -Code nodes in the fixture are mocked because their concrete outputs are not -relevant to verifying variable propagation semantics. -""" - -from .test_mock_config import MockConfigBuilder -from .test_table_runner import TableTestRunner, WorkflowTestCase - - -def test_update_conversation_variable_in_iteration(): - fixture_name = "update-conversation-variable-in-iteration" - user_query = "ensure conversation variable syncs" - - mock_config = ( - MockConfigBuilder() - .with_node_output("1759032363865", {"result": [1]}) - .with_node_output("1759032476318", {"result": ""}) - .build() - ) - - case = WorkflowTestCase( - fixture_path=fixture_name, - use_auto_mock=True, - mock_config=mock_config, - query=user_query, - expected_outputs={"answer": user_query}, - description="Conversation variable updated within iteration should flow to answer output.", - ) - - runner = TableTestRunner() - result = runner.run_test_case(case) - - assert result.success, f"Workflow execution failed: {result.error}" - assert result.actual_outputs is not None - assert result.actual_outputs.get("answer") == user_query diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_variable_aggregator.py b/api/tests/unit_tests/core/workflow/graph_engine/test_variable_aggregator.py deleted file mode 100644 index f63e8ff4ce..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_variable_aggregator.py +++ /dev/null @@ -1,58 +0,0 @@ -from unittest.mock import patch - -import pytest - -from dify_graph.enums import WorkflowNodeExecutionStatus -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.template_transform.template_transform_node import TemplateTransformNode - -from .test_table_runner 
import TableTestRunner, WorkflowTestCase - - -class TestVariableAggregator: - """Test cases for the variable aggregator workflow.""" - - @pytest.mark.parametrize( - ("switch1", "switch2", "expected_group1", "expected_group2", "description"), - [ - (0, 0, "switch 1 off", "switch 2 off", "Both switches off"), - (0, 1, "switch 1 off", "switch 2 on", "Switch1 off, Switch2 on"), - (1, 0, "switch 1 on", "switch 2 off", "Switch1 on, Switch2 off"), - (1, 1, "switch 1 on", "switch 2 on", "Both switches on"), - ], - ) - def test_variable_aggregator_combinations( - self, - switch1: int, - switch2: int, - expected_group1: str, - expected_group2: str, - description: str, - ) -> None: - """Test all four combinations of switch1 and switch2.""" - - def mock_template_transform_run(self): - """Mock the TemplateTransformNode._run() method to return results based on node title.""" - title = self._node_data.title - return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, inputs={}, outputs={"output": title}) - - with patch.object( - TemplateTransformNode, - "_run", - mock_template_transform_run, - ): - runner = TableTestRunner() - - test_case = WorkflowTestCase( - fixture_path="dual_switch_variable_aggregator_workflow", - inputs={"switch1": switch1, "switch2": switch2}, - expected_outputs={"group1": expected_group1, "group2": expected_group2}, - description=description, - ) - - result = runner.run_test_case(test_case) - - assert result.success, f"Test failed: {result.error}" - assert result.actual_outputs == test_case.expected_outputs, ( - f"Output mismatch: expected {test_case.expected_outputs}, got {result.actual_outputs}" - ) diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_worker.py b/api/tests/unit_tests/core/workflow/graph_engine/test_worker.py deleted file mode 100644 index bc00b49fba..0000000000 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_worker.py +++ /dev/null @@ -1,145 +0,0 @@ -import queue -from collections.abc import Generator 
-from datetime import UTC, datetime, timedelta -from types import SimpleNamespace -from unittest.mock import MagicMock, patch - -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from dify_graph.graph_engine.ready_queue import InMemoryReadyQueue -from dify_graph.graph_engine.worker import Worker -from dify_graph.graph_events import NodeRunFailedEvent, NodeRunStartedEvent - - -def test_build_fallback_failure_event_uses_naive_utc_and_failed_node_run_result(mocker) -> None: - fixed_time = datetime(2024, 1, 1, 12, 0, 0, tzinfo=UTC).replace(tzinfo=None) - mocker.patch("dify_graph.graph_engine.worker.naive_utc_now", return_value=fixed_time) - - worker = Worker( - ready_queue=InMemoryReadyQueue(), - event_queue=queue.Queue(), - graph=MagicMock(), - layers=[], - ) - node = SimpleNamespace( - execution_id="exec-1", - id="node-1", - node_type=BuiltinNodeTypes.LLM, - ) - - event = worker._build_fallback_failure_event(node, RuntimeError("boom")) - - assert event.start_at == fixed_time - assert event.finished_at == fixed_time - assert event.error == "boom" - assert event.node_run_result.status == WorkflowNodeExecutionStatus.FAILED - assert event.node_run_result.error == "boom" - assert event.node_run_result.error_type == "RuntimeError" - - -def test_worker_fallback_failure_event_reuses_observed_start_time() -> None: - start_at = datetime(2024, 1, 1, 12, 0, 0, tzinfo=UTC).replace(tzinfo=None) - failure_time = start_at + timedelta(seconds=5) - captured_events: list[NodeRunFailedEvent | NodeRunStartedEvent] = [] - - class FakeNode: - execution_id = "exec-1" - id = "node-1" - node_type = BuiltinNodeTypes.LLM - - def ensure_execution_id(self) -> str: - return self.execution_id - - def run(self) -> Generator[NodeRunStartedEvent, None, None]: - yield NodeRunStartedEvent( - id=self.execution_id, - node_id=self.id, - node_type=self.node_type, - node_title="LLM", - start_at=start_at, - ) - - worker = Worker( - ready_queue=MagicMock(), - event_queue=MagicMock(), 
- graph=MagicMock(nodes={"node-1": FakeNode()}), - layers=[], - ) - - worker._ready_queue.get.side_effect = ["node-1"] - - def put_side_effect(event: NodeRunFailedEvent | NodeRunStartedEvent) -> None: - captured_events.append(event) - if len(captured_events) == 1: - raise RuntimeError("queue boom") - worker.stop() - - worker._event_queue.put.side_effect = put_side_effect - - with patch("dify_graph.graph_engine.worker.naive_utc_now", return_value=failure_time): - worker.run() - - fallback_event = captured_events[-1] - - assert isinstance(fallback_event, NodeRunFailedEvent) - assert fallback_event.start_at == start_at - assert fallback_event.finished_at == failure_time - assert fallback_event.error == "queue boom" - assert fallback_event.node_run_result.status == WorkflowNodeExecutionStatus.FAILED - - -def test_worker_fallback_failure_event_ignores_nested_iteration_child_start_times() -> None: - parent_start = datetime(2024, 1, 1, 12, 0, 0, tzinfo=UTC).replace(tzinfo=None) - child_start = parent_start + timedelta(seconds=3) - failure_time = parent_start + timedelta(seconds=5) - captured_events: list[NodeRunFailedEvent | NodeRunStartedEvent] = [] - - class FakeIterationNode: - execution_id = "iteration-exec" - id = "iteration-node" - node_type = BuiltinNodeTypes.ITERATION - - def ensure_execution_id(self) -> str: - return self.execution_id - - def run(self) -> Generator[NodeRunStartedEvent, None, None]: - yield NodeRunStartedEvent( - id=self.execution_id, - node_id=self.id, - node_type=self.node_type, - node_title="Iteration", - start_at=parent_start, - ) - yield NodeRunStartedEvent( - id="child-exec", - node_id="child-node", - node_type=BuiltinNodeTypes.LLM, - node_title="LLM", - start_at=child_start, - in_iteration_id=self.id, - ) - - worker = Worker( - ready_queue=MagicMock(), - event_queue=MagicMock(), - graph=MagicMock(nodes={"iteration-node": FakeIterationNode()}), - layers=[], - ) - - worker._ready_queue.get.side_effect = ["iteration-node"] - - def 
put_side_effect(event: NodeRunFailedEvent | NodeRunStartedEvent) -> None: - captured_events.append(event) - if len(captured_events) == 2: - raise RuntimeError("queue boom") - worker.stop() - - worker._event_queue.put.side_effect = put_side_effect - - with patch("dify_graph.graph_engine.worker.naive_utc_now", return_value=failure_time): - worker.run() - - fallback_event = captured_events[-1] - - assert isinstance(fallback_event, NodeRunFailedEvent) - assert fallback_event.start_at == parent_start - assert fallback_event.finished_at == failure_time diff --git a/api/tests/unit_tests/core/workflow/nodes/agent/test_message_transformer.py b/api/tests/unit_tests/core/workflow/nodes/agent/test_message_transformer.py new file mode 100644 index 0000000000..cbc920705c --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/agent/test_message_transformer.py @@ -0,0 +1,34 @@ +from unittest.mock import patch + +from graphon.enums import BuiltinNodeTypes + +from core.tools.utils.message_transformer import ToolFileMessageTransformer +from core.workflow.nodes.agent.message_transformer import AgentMessageTransformer + + +def test_transform_passes_conversation_id_to_tool_file_message_transformer() -> None: + messages = iter(()) + transformer = AgentMessageTransformer() + + with patch.object(ToolFileMessageTransformer, "transform_tool_invoke_messages", return_value=iter(())) as transform: + result = list( + transformer.transform( + messages=messages, + tool_info={}, + parameters_for_log={}, + user_id="user-id", + tenant_id="tenant-id", + conversation_id="conversation-id", + node_type=BuiltinNodeTypes.AGENT, + node_id="node-id", + node_execution_id="execution-id", + ) + ) + + assert len(result) == 2 + transform.assert_called_once_with( + messages=messages, + user_id="user-id", + tenant_id="tenant-id", + conversation_id="conversation-id", + ) diff --git a/api/tests/unit_tests/core/workflow/nodes/agent/test_runtime_support.py 
b/api/tests/unit_tests/core/workflow/nodes/agent/test_runtime_support.py new file mode 100644 index 0000000000..59dd763b59 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/agent/test_runtime_support.py @@ -0,0 +1,50 @@ +from types import SimpleNamespace +from unittest.mock import Mock, patch + +from graphon.model_runtime.entities.model_entities import ModelType + +from core.workflow.nodes.agent.runtime_support import AgentRuntimeSupport + + +def test_fetch_model_reuses_single_model_assembly(): + provider_configuration = SimpleNamespace( + get_current_credentials=Mock(return_value={"api_key": "x"}), + provider=SimpleNamespace(provider="openai"), + ) + model_type_instance = SimpleNamespace(get_model_schema=Mock(return_value="schema")) + provider_model_bundle = SimpleNamespace( + configuration=provider_configuration, + model_type_instance=model_type_instance, + ) + model_instance = Mock() + assembly = SimpleNamespace( + provider_manager=Mock(), + model_manager=Mock(), + ) + assembly.provider_manager.get_provider_model_bundle.return_value = provider_model_bundle + assembly.model_manager.get_model_instance.return_value = model_instance + + with patch( + "core.workflow.nodes.agent.runtime_support.create_plugin_model_assembly", + return_value=assembly, + ) as mock_assembly: + resolved_instance, resolved_schema = AgentRuntimeSupport().fetch_model( + tenant_id="tenant-1", + user_id="user-1", + value={"provider": "openai", "model": "gpt-4o-mini", "model_type": "llm"}, + ) + + assert resolved_instance is model_instance + assert resolved_schema == "schema" + mock_assembly.assert_called_once_with(tenant_id="tenant-1", user_id="user-1") + assembly.provider_manager.get_provider_model_bundle.assert_called_once_with( + tenant_id="tenant-1", + provider="openai", + model_type=ModelType.LLM, + ) + assembly.model_manager.get_model_instance.assert_called_once_with( + tenant_id="tenant-1", + provider="openai", + model_type=ModelType.LLM, + model="gpt-4o-mini", + ) diff --git 
a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py index fd563d1be2..7195471eb6 100644 --- a/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py +++ b/api/tests/unit_tests/core/workflow/nodes/answer/test_answer.py @@ -2,13 +2,14 @@ import time import uuid from unittest.mock import MagicMock +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.graph import Graph +from graphon.nodes.answer.answer_node import AnswerNode +from graphon.runtime import GraphRuntimeState, VariablePool + from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.workflow.node_factory import DifyNodeFactory -from dify_graph.enums import WorkflowNodeExecutionStatus -from dify_graph.graph import Graph -from dify_graph.nodes.answer.answer_node import AnswerNode -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from core.workflow.system_variables import build_system_variables from extensions.ext_database import db from tests.workflow_test_utils import build_test_graph_init_params @@ -48,7 +49,7 @@ def test_execute_answer(): # construct variable pool variable_pool = VariablePool( - system_variables=SystemVariable(user_id="aaa", files=[]), + system_variables=build_system_variables(user_id="aaa", files=[]), user_inputs={}, environment_variables=[], conversation_variables=[], diff --git a/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py b/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py index 81d3f5be9c..343bcd3919 100644 --- a/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/base/test_base_node.py @@ -1,9 +1,9 @@ import pytest +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import BuiltinNodeTypes, NodeType +from graphon.nodes.base.node import Node from core.workflow.node_factory 
import get_node_type_classes_mapping -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, NodeType -from dify_graph.nodes.base.node import Node # Ensures that all production node classes are imported and registered. _ = get_node_type_classes_mapping() diff --git a/api/tests/unit_tests/core/workflow/nodes/base/test_get_node_type_classes_mapping.py b/api/tests/unit_tests/core/workflow/nodes/base/test_get_node_type_classes_mapping.py index 972a945ca0..b9371a34f4 100644 --- a/api/tests/unit_tests/core/workflow/nodes/base/test_get_node_type_classes_mapping.py +++ b/api/tests/unit_tests/core/workflow/nodes/base/test_get_node_type_classes_mapping.py @@ -1,19 +1,20 @@ import types from collections.abc import Mapping -from core.workflow.node_factory import get_node_type_classes_mapping -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.enums import BuiltinNodeTypes, NodeType -from dify_graph.nodes.base.node import Node +from graphon.entities.base_node_data import BaseNodeData +from graphon.enums import BuiltinNodeTypes, NodeType +from graphon.nodes.base.node import Node # Import concrete nodes we will assert on (numeric version path) -from dify_graph.nodes.variable_assigner.v1.node import ( +from graphon.nodes.variable_assigner.v1.node import ( VariableAssignerNode as VariableAssignerV1, ) -from dify_graph.nodes.variable_assigner.v2.node import ( +from graphon.nodes.variable_assigner.v2.node import ( VariableAssignerNode as VariableAssignerV2, ) +from core.workflow.node_factory import get_node_type_classes_mapping + def test_variable_assigner_latest_prefers_highest_numeric_version(): # Act diff --git a/api/tests/unit_tests/core/workflow/nodes/code/code_node_spec.py b/api/tests/unit_tests/core/workflow/nodes/code/code_node_spec.py index 784e08edd2..d155124c50 100644 --- a/api/tests/unit_tests/core/workflow/nodes/code/code_node_spec.py +++ 
b/api/tests/unit_tests/core/workflow/nodes/code/code_node_spec.py @@ -1,13 +1,14 @@ -from configs import dify_config -from dify_graph.nodes.code.code_node import CodeNode -from dify_graph.nodes.code.entities import CodeLanguage, CodeNodeData -from dify_graph.nodes.code.exc import ( +from graphon.nodes.code.code_node import CodeNode +from graphon.nodes.code.entities import CodeLanguage, CodeNodeData +from graphon.nodes.code.exc import ( CodeNodeError, DepthLimitError, OutputValidationError, ) -from dify_graph.nodes.code.limits import CodeNodeLimits -from dify_graph.variables.types import SegmentType +from graphon.nodes.code.limits import CodeNodeLimits +from graphon.variables.types import SegmentType + +from configs import dify_config CodeNode._limits = CodeNodeLimits( max_string_length=dify_config.CODE_MAX_STRING_LENGTH, diff --git a/api/tests/unit_tests/core/workflow/nodes/code/entities_spec.py b/api/tests/unit_tests/core/workflow/nodes/code/entities_spec.py deleted file mode 100644 index de7ed0815e..0000000000 --- a/api/tests/unit_tests/core/workflow/nodes/code/entities_spec.py +++ /dev/null @@ -1,352 +0,0 @@ -import pytest -from pydantic import ValidationError - -from dify_graph.nodes.code.entities import CodeLanguage, CodeNodeData -from dify_graph.variables.types import SegmentType - - -class TestCodeNodeDataOutput: - """Test suite for CodeNodeData.Output model.""" - - def test_output_with_string_type(self): - """Test Output with STRING type.""" - output = CodeNodeData.Output(type=SegmentType.STRING) - - assert output.type == SegmentType.STRING - assert output.children is None - - def test_output_with_number_type(self): - """Test Output with NUMBER type.""" - output = CodeNodeData.Output(type=SegmentType.NUMBER) - - assert output.type == SegmentType.NUMBER - assert output.children is None - - def test_output_with_boolean_type(self): - """Test Output with BOOLEAN type.""" - output = CodeNodeData.Output(type=SegmentType.BOOLEAN) - - assert output.type == 
SegmentType.BOOLEAN - - def test_output_with_object_type(self): - """Test Output with OBJECT type.""" - output = CodeNodeData.Output(type=SegmentType.OBJECT) - - assert output.type == SegmentType.OBJECT - - def test_output_with_array_string_type(self): - """Test Output with ARRAY_STRING type.""" - output = CodeNodeData.Output(type=SegmentType.ARRAY_STRING) - - assert output.type == SegmentType.ARRAY_STRING - - def test_output_with_array_number_type(self): - """Test Output with ARRAY_NUMBER type.""" - output = CodeNodeData.Output(type=SegmentType.ARRAY_NUMBER) - - assert output.type == SegmentType.ARRAY_NUMBER - - def test_output_with_array_object_type(self): - """Test Output with ARRAY_OBJECT type.""" - output = CodeNodeData.Output(type=SegmentType.ARRAY_OBJECT) - - assert output.type == SegmentType.ARRAY_OBJECT - - def test_output_with_array_boolean_type(self): - """Test Output with ARRAY_BOOLEAN type.""" - output = CodeNodeData.Output(type=SegmentType.ARRAY_BOOLEAN) - - assert output.type == SegmentType.ARRAY_BOOLEAN - - def test_output_with_nested_children(self): - """Test Output with nested children for OBJECT type.""" - child_output = CodeNodeData.Output(type=SegmentType.STRING) - parent_output = CodeNodeData.Output( - type=SegmentType.OBJECT, - children={"name": child_output}, - ) - - assert parent_output.type == SegmentType.OBJECT - assert parent_output.children is not None - assert "name" in parent_output.children - assert parent_output.children["name"].type == SegmentType.STRING - - def test_output_with_deeply_nested_children(self): - """Test Output with deeply nested children.""" - inner_child = CodeNodeData.Output(type=SegmentType.NUMBER) - middle_child = CodeNodeData.Output( - type=SegmentType.OBJECT, - children={"value": inner_child}, - ) - outer_output = CodeNodeData.Output( - type=SegmentType.OBJECT, - children={"nested": middle_child}, - ) - - assert outer_output.children is not None - assert outer_output.children["nested"].children is not None - 
assert outer_output.children["nested"].children["value"].type == SegmentType.NUMBER - - def test_output_with_multiple_children(self): - """Test Output with multiple children.""" - output = CodeNodeData.Output( - type=SegmentType.OBJECT, - children={ - "name": CodeNodeData.Output(type=SegmentType.STRING), - "age": CodeNodeData.Output(type=SegmentType.NUMBER), - "active": CodeNodeData.Output(type=SegmentType.BOOLEAN), - }, - ) - - assert output.children is not None - assert len(output.children) == 3 - assert output.children["name"].type == SegmentType.STRING - assert output.children["age"].type == SegmentType.NUMBER - assert output.children["active"].type == SegmentType.BOOLEAN - - def test_output_rejects_invalid_type(self): - """Test Output rejects invalid segment types.""" - with pytest.raises(ValidationError): - CodeNodeData.Output(type=SegmentType.FILE) - - def test_output_rejects_array_file_type(self): - """Test Output rejects ARRAY_FILE type.""" - with pytest.raises(ValidationError): - CodeNodeData.Output(type=SegmentType.ARRAY_FILE) - - -class TestCodeNodeDataDependency: - """Test suite for CodeNodeData.Dependency model.""" - - def test_dependency_basic(self): - """Test Dependency with name and version.""" - dependency = CodeNodeData.Dependency(name="numpy", version="1.24.0") - - assert dependency.name == "numpy" - assert dependency.version == "1.24.0" - - def test_dependency_with_complex_version(self): - """Test Dependency with complex version string.""" - dependency = CodeNodeData.Dependency(name="pandas", version=">=2.0.0,<3.0.0") - - assert dependency.name == "pandas" - assert dependency.version == ">=2.0.0,<3.0.0" - - def test_dependency_with_empty_version(self): - """Test Dependency with empty version.""" - dependency = CodeNodeData.Dependency(name="requests", version="") - - assert dependency.name == "requests" - assert dependency.version == "" - - -class TestCodeNodeData: - """Test suite for CodeNodeData model.""" - - def 
test_code_node_data_python3(self): - """Test CodeNodeData with Python3 language.""" - data = CodeNodeData( - title="Test Code Node", - variables=[], - code_language=CodeLanguage.PYTHON3, - code="def main(): return {'result': 42}", - outputs={"result": CodeNodeData.Output(type=SegmentType.NUMBER)}, - ) - - assert data.title == "Test Code Node" - assert data.code_language == CodeLanguage.PYTHON3 - assert data.code == "def main(): return {'result': 42}" - assert "result" in data.outputs - assert data.dependencies is None - - def test_code_node_data_javascript(self): - """Test CodeNodeData with JavaScript language.""" - data = CodeNodeData( - title="JS Code Node", - variables=[], - code_language=CodeLanguage.JAVASCRIPT, - code="function main() { return { result: 'hello' }; }", - outputs={"result": CodeNodeData.Output(type=SegmentType.STRING)}, - ) - - assert data.code_language == CodeLanguage.JAVASCRIPT - assert "result" in data.outputs - assert data.outputs["result"].type == SegmentType.STRING - - def test_code_node_data_with_dependencies(self): - """Test CodeNodeData with dependencies.""" - data = CodeNodeData( - title="Code with Deps", - variables=[], - code_language=CodeLanguage.PYTHON3, - code="import numpy as np\ndef main(): return {'sum': 10}", - outputs={"sum": CodeNodeData.Output(type=SegmentType.NUMBER)}, - dependencies=[ - CodeNodeData.Dependency(name="numpy", version="1.24.0"), - CodeNodeData.Dependency(name="pandas", version="2.0.0"), - ], - ) - - assert data.dependencies is not None - assert len(data.dependencies) == 2 - assert data.dependencies[0].name == "numpy" - assert data.dependencies[1].name == "pandas" - - def test_code_node_data_with_multiple_outputs(self): - """Test CodeNodeData with multiple outputs.""" - data = CodeNodeData( - title="Multi Output", - variables=[], - code_language=CodeLanguage.PYTHON3, - code="def main(): return {'name': 'test', 'count': 5, 'items': ['a', 'b']}", - outputs={ - "name": 
CodeNodeData.Output(type=SegmentType.STRING), - "count": CodeNodeData.Output(type=SegmentType.NUMBER), - "items": CodeNodeData.Output(type=SegmentType.ARRAY_STRING), - }, - ) - - assert len(data.outputs) == 3 - assert data.outputs["name"].type == SegmentType.STRING - assert data.outputs["count"].type == SegmentType.NUMBER - assert data.outputs["items"].type == SegmentType.ARRAY_STRING - - def test_code_node_data_with_object_output(self): - """Test CodeNodeData with nested object output.""" - data = CodeNodeData( - title="Object Output", - variables=[], - code_language=CodeLanguage.PYTHON3, - code="def main(): return {'user': {'name': 'John', 'age': 30}}", - outputs={ - "user": CodeNodeData.Output( - type=SegmentType.OBJECT, - children={ - "name": CodeNodeData.Output(type=SegmentType.STRING), - "age": CodeNodeData.Output(type=SegmentType.NUMBER), - }, - ), - }, - ) - - assert data.outputs["user"].type == SegmentType.OBJECT - assert data.outputs["user"].children is not None - assert len(data.outputs["user"].children) == 2 - - def test_code_node_data_with_array_object_output(self): - """Test CodeNodeData with array of objects output.""" - data = CodeNodeData( - title="Array Object Output", - variables=[], - code_language=CodeLanguage.PYTHON3, - code="def main(): return {'users': [{'name': 'A'}, {'name': 'B'}]}", - outputs={ - "users": CodeNodeData.Output( - type=SegmentType.ARRAY_OBJECT, - children={ - "name": CodeNodeData.Output(type=SegmentType.STRING), - }, - ), - }, - ) - - assert data.outputs["users"].type == SegmentType.ARRAY_OBJECT - assert data.outputs["users"].children is not None - - def test_code_node_data_empty_code(self): - """Test CodeNodeData with empty code.""" - data = CodeNodeData( - title="Empty Code", - variables=[], - code_language=CodeLanguage.PYTHON3, - code="", - outputs={}, - ) - - assert data.code == "" - assert len(data.outputs) == 0 - - def test_code_node_data_multiline_code(self): - """Test CodeNodeData with multiline code.""" - 
multiline_code = """ -def main(): - result = 0 - for i in range(10): - result += i - return {'sum': result} -""" - data = CodeNodeData( - title="Multiline Code", - variables=[], - code_language=CodeLanguage.PYTHON3, - code=multiline_code, - outputs={"sum": CodeNodeData.Output(type=SegmentType.NUMBER)}, - ) - - assert "for i in range(10)" in data.code - assert "result += i" in data.code - - def test_code_node_data_with_special_characters_in_code(self): - """Test CodeNodeData with special characters in code.""" - code_with_special = "def main(): return {'msg': 'Hello\\nWorld\\t!'}" - data = CodeNodeData( - title="Special Chars", - variables=[], - code_language=CodeLanguage.PYTHON3, - code=code_with_special, - outputs={"msg": CodeNodeData.Output(type=SegmentType.STRING)}, - ) - - assert "\\n" in data.code - assert "\\t" in data.code - - def test_code_node_data_with_unicode_in_code(self): - """Test CodeNodeData with unicode characters in code.""" - unicode_code = "def main(): return {'greeting': '你好世界'}" - data = CodeNodeData( - title="Unicode Code", - variables=[], - code_language=CodeLanguage.PYTHON3, - code=unicode_code, - outputs={"greeting": CodeNodeData.Output(type=SegmentType.STRING)}, - ) - - assert "你好世界" in data.code - - def test_code_node_data_empty_dependencies_list(self): - """Test CodeNodeData with empty dependencies list.""" - data = CodeNodeData( - title="No Deps", - variables=[], - code_language=CodeLanguage.PYTHON3, - code="def main(): return {}", - outputs={}, - dependencies=[], - ) - - assert data.dependencies is not None - assert len(data.dependencies) == 0 - - def test_code_node_data_with_boolean_array_output(self): - """Test CodeNodeData with boolean array output.""" - data = CodeNodeData( - title="Boolean Array", - variables=[], - code_language=CodeLanguage.PYTHON3, - code="def main(): return {'flags': [True, False, True]}", - outputs={"flags": CodeNodeData.Output(type=SegmentType.ARRAY_BOOLEAN)}, - ) - - assert data.outputs["flags"].type == 
SegmentType.ARRAY_BOOLEAN - - def test_code_node_data_with_number_array_output(self): - """Test CodeNodeData with number array output.""" - data = CodeNodeData( - title="Number Array", - variables=[], - code_language=CodeLanguage.PYTHON3, - code="def main(): return {'values': [1, 2, 3, 4, 5]}", - outputs={"values": CodeNodeData.Output(type=SegmentType.ARRAY_NUMBER)}, - ) - - assert data.outputs["values"].type == SegmentType.ARRAY_NUMBER diff --git a/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py b/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py index 859115ceb3..fb03ae9998 100644 --- a/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/datasource/test_datasource_node.py @@ -1,7 +1,8 @@ +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent + +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY from core.workflow.nodes.datasource.datasource_node import DatasourceNode -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY -from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from dify_graph.node_events import NodeRunResult, StreamChunkEvent, StreamCompletedEvent class _VarSeg: diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_config.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_config.py deleted file mode 100644 index cd822a6f89..0000000000 --- a/api/tests/unit_tests/core/workflow/nodes/http_request/test_config.py +++ /dev/null @@ -1,33 +0,0 @@ -from dify_graph.nodes.http_request import build_http_request_config - - -def test_build_http_request_config_uses_literal_defaults(): - config = build_http_request_config() - - assert config.max_connect_timeout == 10 - assert config.max_read_timeout == 600 - assert config.max_write_timeout == 600 - assert 
config.max_binary_size == 10 * 1024 * 1024 - assert config.max_text_size == 1 * 1024 * 1024 - assert config.ssl_verify is True - assert config.ssrf_default_max_retries == 3 - - -def test_build_http_request_config_supports_explicit_overrides(): - config = build_http_request_config( - max_connect_timeout=5, - max_read_timeout=30, - max_write_timeout=40, - max_binary_size=2048, - max_text_size=1024, - ssl_verify=False, - ssrf_default_max_retries=8, - ) - - assert config.max_connect_timeout == 5 - assert config.max_read_timeout == 30 - assert config.max_write_timeout == 40 - assert config.max_binary_size == 2048 - assert config.max_text_size == 1024 - assert config.ssl_verify is False - assert config.ssrf_default_max_retries == 8 diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_entities.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_entities.py deleted file mode 100644 index fec6ad90eb..0000000000 --- a/api/tests/unit_tests/core/workflow/nodes/http_request/test_entities.py +++ /dev/null @@ -1,233 +0,0 @@ -import json -from unittest.mock import Mock, PropertyMock, patch - -import httpx -import pytest - -from dify_graph.nodes.http_request.entities import Response - - -@pytest.fixture -def mock_response(): - response = Mock(spec=httpx.Response) - response.headers = {} - return response - - -def test_is_file_with_attachment_disposition(mock_response): - """Test is_file when content-disposition header contains 'attachment'""" - mock_response.headers = {"content-disposition": "attachment; filename=test.pdf", "content-type": "application/pdf"} - response = Response(mock_response) - assert response.is_file - - -def test_is_file_with_filename_disposition(mock_response): - """Test is_file when content-disposition header contains filename parameter""" - mock_response.headers = {"content-disposition": "inline; filename=test.pdf", "content-type": "application/pdf"} - response = Response(mock_response) - assert response.is_file - - 
-@pytest.mark.parametrize("content_type", ["application/pdf", "image/jpeg", "audio/mp3", "video/mp4"]) -def test_is_file_with_file_content_types(mock_response, content_type): - """Test is_file with various file content types""" - mock_response.headers = {"content-type": content_type} - # Mock binary content - type(mock_response).content = PropertyMock(return_value=bytes([0x00, 0xFF] * 512)) - response = Response(mock_response) - assert response.is_file, f"Content type {content_type} should be identified as a file" - - -@pytest.mark.parametrize( - "content_type", - [ - "application/json", - "application/xml", - "application/javascript", - "application/x-www-form-urlencoded", - "application/yaml", - "application/graphql", - ], -) -def test_text_based_application_types(mock_response, content_type): - """Test common text-based application types are not identified as files""" - mock_response.headers = {"content-type": content_type} - response = Response(mock_response) - assert not response.is_file, f"Content type {content_type} should not be identified as a file" - - -@pytest.mark.parametrize( - ("content", "content_type"), - [ - (b'{"key": "value"}', "application/octet-stream"), - (b"[1, 2, 3]", "application/unknown"), - (b"function test() {}", "application/x-unknown"), - (b"test", "application/binary"), - (b"var x = 1;", "application/data"), - ], -) -def test_content_based_detection(mock_response, content, content_type): - """Test content-based detection for text-like content""" - mock_response.headers = {"content-type": content_type} - type(mock_response).content = PropertyMock(return_value=content) - response = Response(mock_response) - assert not response.is_file, f"Content {content} with type {content_type} should not be identified as a file" - - -@pytest.mark.parametrize( - ("content", "content_type"), - [ - (bytes([0x00, 0xFF] * 512), "application/octet-stream"), - (bytes([0x89, 0x50, 0x4E, 0x47]), "application/unknown"), # PNG magic numbers - (bytes([0xFF, 
0xD8, 0xFF]), "application/binary"), # JPEG magic numbers - ], -) -def test_binary_content_detection(mock_response, content, content_type): - """Test content-based detection for binary content""" - mock_response.headers = {"content-type": content_type} - type(mock_response).content = PropertyMock(return_value=content) - response = Response(mock_response) - assert response.is_file, f"Binary content with type {content_type} should be identified as a file" - - -@pytest.mark.parametrize( - ("content_type", "expected_main_type"), - [ - ("x-world/x-vrml", "model"), # VRML 3D model - ("font/ttf", "application"), # TrueType font - ("text/csv", "text"), # CSV text file - ("unknown/xyz", None), # Unknown type - ], -) -def test_mimetype_based_detection(mock_response, content_type, expected_main_type): - """Test detection using mimetypes.guess_type for non-application content types""" - mock_response.headers = {"content-type": content_type} - type(mock_response).content = PropertyMock(return_value=bytes([0x00])) # Dummy content - - with patch("dify_graph.nodes.http_request.entities.mimetypes.guess_type") as mock_guess_type: - # Mock the return value based on expected_main_type - if expected_main_type: - mock_guess_type.return_value = (f"{expected_main_type}/subtype", None) - else: - mock_guess_type.return_value = (None, None) - - response = Response(mock_response) - - # Check if the result matches our expectation - if expected_main_type in ("application", "image", "audio", "video"): - assert response.is_file, f"Content type {content_type} should be identified as a file" - else: - assert not response.is_file, f"Content type {content_type} should not be identified as a file" - - # Verify that guess_type was called - mock_guess_type.assert_called_once() - - -def test_is_file_with_inline_disposition(mock_response): - """Test is_file when content-disposition is 'inline'""" - mock_response.headers = {"content-disposition": "inline", "content-type": "application/pdf"} - # Mock binary 
content - type(mock_response).content = PropertyMock(return_value=bytes([0x00, 0xFF] * 512)) - response = Response(mock_response) - assert response.is_file - - -def test_is_file_with_no_content_disposition(mock_response): - """Test is_file when no content-disposition header is present""" - mock_response.headers = {"content-type": "application/pdf"} - # Mock binary content - type(mock_response).content = PropertyMock(return_value=bytes([0x00, 0xFF] * 512)) - response = Response(mock_response) - assert response.is_file - - -# UTF-8 Encoding Tests -@pytest.mark.parametrize( - ("content_bytes", "expected_text", "description"), - [ - # Chinese UTF-8 bytes - ( - b'{"message": "\xe4\xbd\xa0\xe5\xa5\xbd\xe4\xb8\x96\xe7\x95\x8c"}', - '{"message": "你好世界"}', - "Chinese characters UTF-8", - ), - # Japanese UTF-8 bytes - ( - b'{"message": "\xe3\x81\x93\xe3\x82\x93\xe3\x81\xab\xe3\x81\xa1\xe3\x81\xaf"}', - '{"message": "こんにちは"}', - "Japanese characters UTF-8", - ), - # Korean UTF-8 bytes - ( - b'{"message": "\xec\x95\x88\xeb\x85\x95\xed\x95\x98\xec\x84\xb8\xec\x9a\x94"}', - '{"message": "안녕하세요"}', - "Korean characters UTF-8", - ), - # Arabic UTF-8 - (b'{"text": "\xd9\x85\xd8\xb1\xd8\xad\xd8\xa8\xd8\xa7"}', '{"text": "مرحبا"}', "Arabic characters UTF-8"), - # European characters UTF-8 - (b'{"text": "Caf\xc3\xa9 M\xc3\xbcnchen"}', '{"text": "Café München"}', "European accented characters"), - # Simple ASCII - (b'{"text": "Hello World"}', '{"text": "Hello World"}', "Simple ASCII text"), - ], -) -def test_text_property_utf8_decoding(mock_response, content_bytes, expected_text, description): - """Test that Response.text properly decodes UTF-8 content with charset_normalizer""" - mock_response.headers = {"content-type": "application/json; charset=utf-8"} - type(mock_response).content = PropertyMock(return_value=content_bytes) - # Mock httpx response.text to return something different (simulating potential encoding issues) - mock_response.text = "incorrect-fallback-text" # To ensure we 
are not falling back to httpx's text property - - response = Response(mock_response) - - # Our enhanced text property should decode properly using charset_normalizer - assert response.text == expected_text, ( - f"Failed for {description}: got {repr(response.text)}, expected {repr(expected_text)}" - ) - - -def test_text_property_fallback_to_httpx(mock_response): - """Test that Response.text falls back to httpx.text when charset_normalizer fails""" - mock_response.headers = {"content-type": "application/json"} - - # Create malformed UTF-8 bytes - malformed_bytes = b'{"text": "\xff\xfe\x00\x00 invalid"}' - type(mock_response).content = PropertyMock(return_value=malformed_bytes) - - # Mock httpx.text to return some fallback value - fallback_text = '{"text": "fallback"}' - mock_response.text = fallback_text - - response = Response(mock_response) - - # Should fall back to httpx's text when charset_normalizer fails - assert response.text == fallback_text - - -@pytest.mark.parametrize( - ("json_content", "description"), - [ - # JSON with escaped Unicode (like Flask jsonify()) - ('{"message": "\\u4f60\\u597d\\u4e16\\u754c"}', "JSON with escaped Unicode"), - # JSON with mixed escape sequences and UTF-8 - ('{"mixed": "Hello \\u4f60\\u597d"}', "Mixed escaped and regular text"), - # JSON with complex escape sequences - ('{"complex": "\\ud83d\\ude00\\u4f60\\u597d"}', "Emoji and Chinese escapes"), - ], -) -def test_text_property_with_escaped_unicode(mock_response, json_content, description): - """Test Response.text with JSON containing Unicode escape sequences""" - mock_response.headers = {"content-type": "application/json"} - - content_bytes = json_content.encode("utf-8") - type(mock_response).content = PropertyMock(return_value=content_bytes) - mock_response.text = json_content # httpx would return the same for valid UTF-8 - - response = Response(mock_response) - - # Should preserve the escape sequences (valid JSON) - assert response.text == json_content, f"Failed for 
{description}" - - # The text should be valid JSON that can be parsed back to proper Unicode - parsed = json.loads(response.text) - assert isinstance(parsed, dict), f"Invalid JSON for {description}" diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py index cea7195417..a5026b40cf 100644 --- a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py +++ b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py @@ -1,20 +1,20 @@ import pytest - -from configs import dify_config -from core.helper.ssrf_proxy import ssrf_proxy -from dify_graph.file.file_manager import file_manager -from dify_graph.nodes.http_request import ( +from graphon.file.file_manager import file_manager +from graphon.nodes.http_request import ( BodyData, HttpRequestNodeAuthorization, HttpRequestNodeBody, HttpRequestNodeConfig, HttpRequestNodeData, ) -from dify_graph.nodes.http_request.entities import HttpRequestNodeTimeout -from dify_graph.nodes.http_request.exc import AuthorizationConfigError -from dify_graph.nodes.http_request.executor import Executor -from dify_graph.runtime import VariablePool -from dify_graph.system_variable import SystemVariable +from graphon.nodes.http_request.entities import HttpRequestNodeTimeout +from graphon.nodes.http_request.exc import AuthorizationConfigError +from graphon.nodes.http_request.executor import Executor +from graphon.runtime import VariablePool + +from configs import dify_config +from core.helper.ssrf_proxy import ssrf_proxy +from core.workflow.system_variables import default_system_variables HTTP_REQUEST_CONFIG = HttpRequestNodeConfig( max_connect_timeout=dify_config.HTTP_REQUEST_MAX_CONNECT_TIMEOUT, @@ -30,7 +30,7 @@ HTTP_REQUEST_CONFIG = HttpRequestNodeConfig( def test_executor_with_json_body_and_number_variable(): # Prepare the variable pool variable_pool = 
VariablePool( - system_variables=SystemVariable.default(), + system_variables=default_system_variables(), user_inputs={}, ) variable_pool.add(["pre_node_id", "number"], 42) @@ -86,7 +86,7 @@ def test_executor_with_json_body_and_number_variable(): def test_executor_with_json_body_and_object_variable(): # Prepare the variable pool variable_pool = VariablePool( - system_variables=SystemVariable.default(), + system_variables=default_system_variables(), user_inputs={}, ) variable_pool.add(["pre_node_id", "object"], {"name": "John Doe", "age": 30, "email": "john@example.com"}) @@ -144,7 +144,7 @@ def test_executor_with_json_body_and_object_variable(): def test_executor_with_json_body_and_nested_object_variable(): # Prepare the variable pool variable_pool = VariablePool( - system_variables=SystemVariable.default(), + system_variables=default_system_variables(), user_inputs={}, ) variable_pool.add(["pre_node_id", "object"], {"name": "John Doe", "age": 30, "email": "john@example.com"}) @@ -201,7 +201,7 @@ def test_executor_with_json_body_and_nested_object_variable(): def test_extract_selectors_from_template_with_newline(): - variable_pool = VariablePool(system_variables=SystemVariable.default()) + variable_pool = VariablePool(system_variables=default_system_variables()) variable_pool.add(("node_id", "custom_query"), "line1\nline2") node_data = HttpRequestNodeData( title="Test JSON Body with Nested Object Variable", @@ -231,7 +231,7 @@ def test_extract_selectors_from_template_with_newline(): def test_executor_with_form_data(): # Prepare the variable pool variable_pool = VariablePool( - system_variables=SystemVariable.default(), + system_variables=default_system_variables(), user_inputs={}, ) variable_pool.add(["pre_node_id", "text_field"], "Hello, World!") @@ -320,7 +320,7 @@ def test_init_headers(): node_data=node_data, timeout=timeout, http_request_config=HTTP_REQUEST_CONFIG, - variable_pool=VariablePool(system_variables=SystemVariable.default()), + 
variable_pool=VariablePool(system_variables=default_system_variables()), http_client=ssrf_proxy, file_manager=file_manager, ) @@ -357,7 +357,7 @@ def test_init_params(): node_data=node_data, timeout=timeout, http_request_config=HTTP_REQUEST_CONFIG, - variable_pool=VariablePool(system_variables=SystemVariable.default()), + variable_pool=VariablePool(system_variables=default_system_variables()), http_client=ssrf_proxy, file_manager=file_manager, ) @@ -390,7 +390,7 @@ def test_init_params(): def test_empty_api_key_raises_error_bearer(): """Test that empty API key raises AuthorizationConfigError for bearer auth.""" - variable_pool = VariablePool(system_variables=SystemVariable.default()) + variable_pool = VariablePool(system_variables=default_system_variables()) node_data = HttpRequestNodeData( title="test", method="get", @@ -417,7 +417,7 @@ def test_empty_api_key_raises_error_bearer(): def test_empty_api_key_raises_error_basic(): """Test that empty API key raises AuthorizationConfigError for basic auth.""" - variable_pool = VariablePool(system_variables=SystemVariable.default()) + variable_pool = VariablePool(system_variables=default_system_variables()) node_data = HttpRequestNodeData( title="test", method="get", @@ -444,7 +444,7 @@ def test_empty_api_key_raises_error_basic(): def test_empty_api_key_raises_error_custom(): """Test that empty API key raises AuthorizationConfigError for custom auth.""" - variable_pool = VariablePool(system_variables=SystemVariable.default()) + variable_pool = VariablePool(system_variables=default_system_variables()) node_data = HttpRequestNodeData( title="test", method="get", @@ -471,7 +471,7 @@ def test_empty_api_key_raises_error_custom(): def test_whitespace_only_api_key_raises_error(): """Test that whitespace-only API key raises AuthorizationConfigError.""" - variable_pool = VariablePool(system_variables=SystemVariable.default()) + variable_pool = VariablePool(system_variables=default_system_variables()) node_data = 
HttpRequestNodeData( title="test", method="get", @@ -498,7 +498,7 @@ def test_whitespace_only_api_key_raises_error(): def test_valid_api_key_works(): """Test that valid API key works correctly for bearer auth.""" - variable_pool = VariablePool(system_variables=SystemVariable.default()) + variable_pool = VariablePool(system_variables=default_system_variables()) node_data = HttpRequestNodeData( title="test", method="get", @@ -537,7 +537,7 @@ def test_executor_with_json_body_and_unquoted_uuid_variable(): test_uuid = "57eeeeb1-450b-482c-81b9-4be77e95dee2" variable_pool = VariablePool( - system_variables=SystemVariable.default(), + system_variables=default_system_variables(), user_inputs={}, ) variable_pool.add(["pre_node_id", "uuid"], test_uuid) @@ -584,7 +584,7 @@ def test_executor_with_json_body_and_unquoted_uuid_with_newlines(): test_uuid = "57eeeeb1-450b-482c-81b9-4be77e95dee2" variable_pool = VariablePool( - system_variables=SystemVariable.default(), + system_variables=default_system_variables(), user_inputs={}, ) variable_pool.add(["pre_node_id", "uuid"], test_uuid) @@ -625,7 +625,7 @@ def test_executor_with_json_body_and_unquoted_uuid_with_newlines(): def test_executor_with_json_body_preserves_numbers_and_strings(): """Test that numbers are preserved and string values are properly quoted.""" variable_pool = VariablePool( - system_variables=SystemVariable.default(), + system_variables=default_system_variables(), user_inputs={}, ) variable_pool.add(["node", "count"], 42) diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py index 5e34bf1d94..4705b3f76e 100644 --- a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py @@ -3,16 +3,17 @@ from typing import Any import httpx import pytest +from graphon.enums import 
WorkflowNodeExecutionStatus +from graphon.file.file_manager import file_manager +from graphon.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, HttpRequestNode, HttpRequestNodeConfig +from graphon.nodes.http_request.entities import HttpRequestNodeTimeout, Response +from graphon.runtime import GraphRuntimeState, VariablePool from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.helper.ssrf_proxy import ssrf_proxy from core.tools.tool_file_manager import ToolFileManager -from dify_graph.enums import WorkflowNodeExecutionStatus -from dify_graph.file.file_manager import file_manager -from dify_graph.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, HttpRequestNode, HttpRequestNodeConfig -from dify_graph.nodes.http_request.entities import HttpRequestNodeTimeout, Response -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from core.workflow.node_runtime import DifyFileReferenceFactory +from core.workflow.system_variables import build_system_variables from tests.workflow_test_utils import build_test_graph_init_params HTTP_REQUEST_CONFIG = HttpRequestNodeConfig( @@ -109,7 +110,7 @@ def _build_http_node( call_depth=0, ) graph_runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=SystemVariable(user_id="user", files=[]), user_inputs={}), + variable_pool=VariablePool(system_variables=build_system_variables(user_id="user", files=[]), user_inputs={}), start_at=time.perf_counter(), ) return HttpRequestNode( @@ -121,6 +122,7 @@ def _build_http_node( http_client=ssrf_proxy, tool_file_manager_factory=ToolFileManager, file_manager=file_manager, + file_reference_factory=DifyFileReferenceFactory(graph_init_params.run_context), ) @@ -161,7 +163,7 @@ def test_run_passes_node_data_ssl_verify_to_executor(monkeypatch: pytest.MonkeyP ) ) - monkeypatch.setattr("dify_graph.nodes.http_request.node.Executor", FakeExecutor) + 
monkeypatch.setattr("graphon.nodes.http_request.node.Executor", FakeExecutor) result = node._run() diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py b/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py index d52dfa2a65..d16e1233ac 100644 --- a/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py +++ b/api/tests/unit_tests/core/workflow/nodes/human_input/test_email_delivery_config.py @@ -1,5 +1,6 @@ -from dify_graph.nodes.human_input.entities import EmailDeliveryConfig, EmailRecipients -from dify_graph.runtime import VariablePool +from graphon.runtime import VariablePool + +from core.workflow.human_input_compat import EmailDeliveryConfig, EmailRecipients def test_render_body_template_replaces_variable_values(): diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py b/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py index 55aa62a1c0..a2cdbbf132 100644 --- a/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py +++ b/api/tests/unit_tests/core/workflow/nodes/human_input/test_entities.py @@ -2,42 +2,138 @@ Unit tests for human input node entities. 
""" +from collections.abc import Mapping +from dataclasses import dataclass, field +from datetime import datetime, timedelta from types import SimpleNamespace +from typing import Any from unittest.mock import MagicMock import pytest -from pydantic import ValidationError - -from dify_graph.entities import GraphInitParams -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY -from dify_graph.node_events import PauseRequestedEvent -from dify_graph.node_events.node import StreamCompletedEvent -from dify_graph.nodes.human_input.entities import ( - EmailDeliveryConfig, - EmailDeliveryMethod, - EmailRecipients, - ExternalRecipient, +from graphon.entities import GraphInitParams +from graphon.node_events import PauseRequestedEvent +from graphon.node_events.node import StreamCompletedEvent +from graphon.nodes.human_input.entities import ( FormInput, FormInputDefault, HumanInputNodeData, - MemberRecipient, UserAction, - WebAppDeliveryMethod, - _WebAppDeliveryConfig, ) -from dify_graph.nodes.human_input.enums import ( +from graphon.nodes.human_input.enums import ( ButtonStyle, - DeliveryMethodType, - EmailRecipientType, FormInputType, + HumanInputFormStatus, PlaceholderType, TimeoutUnit, ) -from dify_graph.nodes.human_input.human_input_node import HumanInputNode -from dify_graph.repositories.human_input_form_repository import HumanInputFormRepository -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable -from tests.unit_tests.core.workflow.graph_engine.human_input_test_utils import InMemoryHumanInputFormRepository +from graphon.nodes.human_input.human_input_node import HumanInputNode +from graphon.runtime import GraphRuntimeState, VariablePool +from pydantic import ValidationError + +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY +from core.repositories.human_input_repository import ( + FormCreateParams, + HumanInputFormEntity, + HumanInputFormRecipientEntity, + 
HumanInputFormRepository, +) +from core.workflow.human_input_compat import ( + DeliveryMethodType, + EmailDeliveryConfig, + EmailDeliveryMethod, + EmailRecipients, + EmailRecipientType, + ExternalRecipient, + MemberRecipient, + WebAppDeliveryMethod, + _WebAppDeliveryConfig, +) +from core.workflow.node_runtime import DifyHumanInputNodeRuntime +from core.workflow.system_variables import build_system_variables +from libs.datetime_utils import naive_utc_now + + +@dataclass +class _InMemoryFormEntity(HumanInputFormEntity): + form_id: str + rendered: str + token: str | None = None + action_id: str | None = None + data: Mapping[str, Any] | None = None + is_submitted: bool = False + status_value: HumanInputFormStatus = HumanInputFormStatus.WAITING + expiration: datetime = field(default_factory=lambda: naive_utc_now() + timedelta(days=1)) + + @property + def id(self) -> str: + return self.form_id + + @property + def submission_token(self) -> str | None: + return self.token + + @property + def recipients(self) -> list[HumanInputFormRecipientEntity]: + return [] + + @property + def rendered_content(self) -> str: + return self.rendered + + @property + def selected_action_id(self) -> str | None: + return self.action_id + + @property + def submitted_data(self) -> Mapping[str, Any] | None: + return self.data + + @property + def submitted(self) -> bool: + return self.is_submitted + + @property + def status(self) -> HumanInputFormStatus: + return self.status_value + + @property + def expiration_time(self) -> datetime: + return self.expiration + + +class InMemoryHumanInputFormRepository(HumanInputFormRepository): + """Minimal in-memory repository for Dify-owned HumanInputNode behavior tests.""" + + def __init__(self) -> None: + self._form_counter = 0 + self.created_params: list[FormCreateParams] = [] + self.created_forms: list[_InMemoryFormEntity] = [] + self._forms_by_node_id: dict[str, _InMemoryFormEntity] = {} + + def create_form(self, params: FormCreateParams) -> 
HumanInputFormEntity: + self.created_params.append(params) + self._form_counter += 1 + form_id = f"form-{self._form_counter}" + entity = _InMemoryFormEntity( + form_id=form_id, + rendered=params.rendered_content, + token=f"token-{form_id}", + ) + self.created_forms.append(entity) + self._forms_by_node_id[params.node_id] = entity + return entity + + def get_form(self, node_id: str) -> HumanInputFormEntity | None: + return self._forms_by_node_id.get(node_id) + + def set_submission(self, *, action_id: str, form_data: Mapping[str, Any] | None = None) -> None: + if not self.created_forms: + raise AssertionError("no form has been created to attach submission data") + entity = self.created_forms[-1] + entity.action_id = action_id + entity.data = form_data or {} + entity.is_submitted = True + entity.status_value = HumanInputFormStatus.SUBMITTED class TestDeliveryMethod: @@ -54,9 +150,9 @@ class TestDeliveryMethod: def test_email_delivery_method(self): """Test email delivery method creation.""" recipients = EmailRecipients( - whole_workspace=False, + include_bound_group=False, items=[ - MemberRecipient(type=EmailRecipientType.MEMBER, user_id="test-user-123"), + MemberRecipient(type=EmailRecipientType.MEMBER, reference_id="test-user-123"), ExternalRecipient(type=EmailRecipientType.EXTERNAL, email="test@example.com"), ], ) @@ -193,7 +289,7 @@ class TestHumanInputNodeData: EmailDeliveryMethod( enabled=False, # Disabled method should be fine config=EmailDeliveryConfig( - subject="Hi there", body="", recipients=EmailRecipients(whole_workspace=True) + subject="Hi there", body="", recipients=EmailRecipients(include_bound_group=True) ), ), ] @@ -212,7 +308,7 @@ class TestHumanInputNodeData: assert node_data.title == "Test Node" assert node_data.desc is None - assert node_data.delivery_methods == [] + assert node_data.model_dump().get("delivery_methods") is None assert node_data.form_content == "" assert node_data.inputs == [] assert node_data.user_actions == [] @@ -261,10 +357,10 
@@ class TestRecipients: def test_member_recipient(self): """Test member recipient creation.""" - recipient = MemberRecipient(type=EmailRecipientType.MEMBER, user_id="user-123") + recipient = MemberRecipient(type=EmailRecipientType.MEMBER, reference_id="user-123") assert recipient.type == EmailRecipientType.MEMBER - assert recipient.user_id == "user-123" + assert recipient.reference_id == "user-123" def test_external_recipient(self): """Test external recipient creation.""" @@ -273,37 +369,46 @@ class TestRecipients: assert recipient.type == EmailRecipientType.EXTERNAL assert recipient.email == "test@example.com" - def test_email_recipients_whole_workspace(self): - """Test email recipients with whole workspace enabled.""" + def test_email_recipients_bound_group(self): + """Test email recipients with the bound group enabled.""" recipients = EmailRecipients( - whole_workspace=True, items=[MemberRecipient(type=EmailRecipientType.MEMBER, user_id="user-123")] + include_bound_group=True, + items=[MemberRecipient(type=EmailRecipientType.MEMBER, reference_id="user-123")], ) - assert recipients.whole_workspace is True - assert len(recipients.items) == 1 # Items are preserved even when whole_workspace is True + assert recipients.include_bound_group is True + assert len(recipients.items) == 1 # Items are preserved even when include_bound_group is True def test_email_recipients_specific_users(self): """Test email recipients with specific users.""" recipients = EmailRecipients( - whole_workspace=False, + include_bound_group=False, items=[ - MemberRecipient(type=EmailRecipientType.MEMBER, user_id="user-123"), + MemberRecipient(type=EmailRecipientType.MEMBER, reference_id="user-123"), ExternalRecipient(type=EmailRecipientType.EXTERNAL, email="external@example.com"), ], ) - assert recipients.whole_workspace is False + assert recipients.include_bound_group is False assert len(recipients.items) == 2 - assert recipients.items[0].user_id == "user-123" + assert 
recipients.items[0].reference_id == "user-123" assert recipients.items[1].email == "external@example.com" + def test_legacy_recipient_keys_are_rejected(self): + with pytest.raises(ValidationError): + MemberRecipient(type=EmailRecipientType.MEMBER, user_id="user-123") + + recipients = EmailRecipients(whole_workspace=True, items=[]) + assert recipients.include_bound_group is True + assert recipients.items == [] + class TestHumanInputNodeVariableResolution: """Tests for resolving variable-based defaults in HumanInputNode.""" def test_resolves_variable_defaults(self): variable_pool = VariablePool( - system_variables=SystemVariable( + system_variables=build_system_variables( user_id="user", app_id="app", workflow_id="workflow", @@ -353,17 +458,19 @@ class TestHumanInputNodeVariableResolution: mock_repo.create_form.return_value = SimpleNamespace( id="form-1", rendered_content="Provide your name", - web_app_token="token", + submission_token="token", recipients=[], submitted=False, ) + runtime = DifyHumanInputNodeRuntime(graph_init_params.run_context) + runtime._build_form_repository = MagicMock(return_value=mock_repo) # type: ignore[attr-defined] node = HumanInputNode( id=config["id"], config=config, graph_init_params=graph_init_params, graph_runtime_state=runtime_state, - form_repository=mock_repo, + runtime=runtime, ) run_result = node._run() @@ -378,7 +485,7 @@ class TestHumanInputNodeVariableResolution: def test_debugger_falls_back_to_recipient_token_when_webapp_disabled(self): variable_pool = VariablePool( - system_variables=SystemVariable( + system_variables=build_system_variables( user_id="user", app_id="app", workflow_id="workflow", @@ -416,28 +523,96 @@ class TestHumanInputNodeVariableResolution: mock_repo.create_form.return_value = SimpleNamespace( id="form-2", rendered_content="Provide your name", - web_app_token="console-token", + submission_token="console-token", recipients=[SimpleNamespace(token="recipient-token")], submitted=False, ) + runtime = 
DifyHumanInputNodeRuntime(graph_init_params.run_context) + runtime._build_form_repository = MagicMock(return_value=mock_repo) # type: ignore[attr-defined] node = HumanInputNode( id=config["id"], config=config, graph_init_params=graph_init_params, graph_runtime_state=runtime_state, - form_repository=mock_repo, + runtime=runtime, ) run_result = node._run() pause_event = next(run_result) assert isinstance(pause_event, PauseRequestedEvent) - assert pause_event.reason.form_token == "console-token" + assert not hasattr(pause_event.reason, "form_token") + + def test_webapp_runtime_keeps_form_visible_in_ui_when_webapp_delivery_is_enabled(self): + variable_pool = VariablePool( + system_variables=build_system_variables( + user_id="user", + app_id="app", + workflow_id="workflow", + workflow_execution_id="exec-4", + ), + user_inputs={}, + conversation_variables=[], + ) + runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0) + graph_init_params = GraphInitParams( + workflow_id="workflow", + graph_config={"nodes": [], "edges": []}, + run_context={ + DIFY_RUN_CONTEXT_KEY: { + "tenant_id": "tenant", + "app_id": "app", + "user_id": "end-user-1", + "user_from": "end-user", + "invoke_from": "web-app", + } + }, + call_depth=0, + ) + + config = { + "id": "human", + "data": { + "type": "human-input", + "title": "Human Input", + "form_content": "Provide your name", + "inputs": [], + "user_actions": [{"id": "submit", "title": "Submit"}], + "delivery_methods": [{"enabled": True, "type": "webapp", "config": {}}], + }, + } + + mock_repo = MagicMock(spec=HumanInputFormRepository) + mock_repo.get_form.return_value = None + mock_repo.create_form.return_value = SimpleNamespace( + id="form-4", + rendered_content="Provide your name", + submission_token="token", + recipients=[], + submitted=False, + ) + + runtime = DifyHumanInputNodeRuntime(graph_init_params.run_context) + runtime._build_form_repository = MagicMock(return_value=mock_repo) # type: ignore[attr-defined] + node = 
HumanInputNode( + id=config["id"], + config=config, + graph_init_params=graph_init_params, + graph_runtime_state=runtime_state, + runtime=runtime, + ) + + run_result = node._run() + pause_event = next(run_result) + + assert isinstance(pause_event, PauseRequestedEvent) + params = mock_repo.create_form.call_args.args[0] + assert params.display_in_ui is True def test_debugger_debug_mode_overrides_email_recipients(self): variable_pool = VariablePool( - system_variables=SystemVariable( + system_variables=build_system_variables( user_id="user-123", app_id="app", workflow_id="workflow", @@ -472,7 +647,7 @@ class TestHumanInputNodeVariableResolution: enabled=True, config=EmailDeliveryConfig( recipients=EmailRecipients( - whole_workspace=False, + include_bound_group=False, items=[ExternalRecipient(type=EmailRecipientType.EXTERNAL, email="target@example.com")], ), subject="Subject", @@ -489,17 +664,19 @@ class TestHumanInputNodeVariableResolution: mock_repo.create_form.return_value = SimpleNamespace( id="form-3", rendered_content="Provide your name", - web_app_token="token", + submission_token="token", recipients=[], submitted=False, ) + runtime = DifyHumanInputNodeRuntime(graph_init_params.run_context) + runtime._build_form_repository = MagicMock(return_value=mock_repo) # type: ignore[attr-defined] node = HumanInputNode( id=config["id"], config=config, graph_init_params=graph_init_params, graph_runtime_state=runtime_state, - form_repository=mock_repo, + runtime=runtime, ) run_result = node._run() @@ -511,11 +688,11 @@ class TestHumanInputNodeVariableResolution: method = params.delivery_methods[0] assert isinstance(method, EmailDeliveryMethod) assert method.config.debug_mode is True - assert method.config.recipients.whole_workspace is False + assert method.config.recipients.include_bound_group is False assert len(method.config.recipients.items) == 1 recipient = method.config.recipients.items[0] assert isinstance(recipient, MemberRecipient) - assert recipient.user_id == 
"user-123" + assert recipient.reference_id == "user-123" class TestValidation: @@ -552,7 +729,7 @@ class TestHumanInputNodeRenderedContent: def test_replaces_outputs_placeholders_after_submission(self): variable_pool = VariablePool( - system_variables=SystemVariable( + system_variables=build_system_variables( user_id="user", app_id="app", workflow_id="workflow", @@ -591,12 +768,14 @@ class TestHumanInputNodeRenderedContent: config = {"id": "human", "data": node_data.model_dump()} form_repository = InMemoryHumanInputFormRepository() + runtime = DifyHumanInputNodeRuntime(graph_init_params.run_context) + runtime._build_form_repository = MagicMock(return_value=form_repository) # type: ignore[attr-defined] node = HumanInputNode( id=config["id"], config=config, graph_init_params=graph_init_params, graph_runtime_state=runtime_state, - form_repository=form_repository, + runtime=runtime, ) pause_gen = node._run() diff --git a/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py b/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py index b0ed47158d..52802c7ce1 100644 --- a/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py +++ b/api/tests/unit_tests/core/workflow/nodes/human_input/test_human_input_form_filled_event.py @@ -1,18 +1,20 @@ import datetime from types import SimpleNamespace -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY, GraphInitParams -from dify_graph.enums import BuiltinNodeTypes -from dify_graph.graph_events import ( +from graphon.entities import GraphInitParams +from graphon.enums import BuiltinNodeTypes +from graphon.graph_events import ( NodeRunHumanInputFormFilledEvent, NodeRunHumanInputFormTimeoutEvent, NodeRunStartedEvent, ) -from dify_graph.nodes.human_input.enums import HumanInputFormStatus -from dify_graph.nodes.human_input.human_input_node 
import HumanInputNode -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from graphon.nodes.human_input.enums import HumanInputFormStatus +from graphon.nodes.human_input.human_input_node import HumanInputNode +from graphon.runtime import GraphRuntimeState, VariablePool + +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom +from core.workflow.node_runtime import DifyHumanInputNodeRuntime +from core.workflow.system_variables import default_system_variables from libs.datetime_utils import naive_utc_now @@ -25,7 +27,7 @@ class _FakeFormRepository: def _build_node(form_content: str = "Please enter your name:\n\n{{#$output.name#}}") -> HumanInputNode: - system_variables = SystemVariable.default() + system_variables = default_system_variables() graph_runtime_state = GraphRuntimeState( variable_pool=VariablePool(system_variables=system_variables, user_inputs={}, environment_variables=[]), start_at=0.0, @@ -85,11 +87,12 @@ def _build_node(form_content: str = "Please enter your name:\n\n{{#$output.name# graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, form_repository=repo, + runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context), ) def _build_timeout_node() -> HumanInputNode: - system_variables = SystemVariable.default() + system_variables = default_system_variables() graph_runtime_state = GraphRuntimeState( variable_pool=VariablePool(system_variables=system_variables, user_inputs={}, environment_variables=[]), start_at=0.0, @@ -149,6 +152,7 @@ def _build_timeout_node() -> HumanInputNode: graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, form_repository=repo, + runtime=DifyHumanInputNodeRuntime(graph_init_params.run_context), ) diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/entities_spec.py b/api/tests/unit_tests/core/workflow/nodes/iteration/entities_spec.py deleted file mode 100644 
index 93c199514e..0000000000 --- a/api/tests/unit_tests/core/workflow/nodes/iteration/entities_spec.py +++ /dev/null @@ -1,339 +0,0 @@ -from dify_graph.nodes.iteration.entities import ( - ErrorHandleMode, - IterationNodeData, - IterationStartNodeData, - IterationState, -) - - -class TestErrorHandleMode: - """Test suite for ErrorHandleMode enum.""" - - def test_terminated_value(self): - """Test TERMINATED enum value.""" - assert ErrorHandleMode.TERMINATED == "terminated" - assert ErrorHandleMode.TERMINATED.value == "terminated" - - def test_continue_on_error_value(self): - """Test CONTINUE_ON_ERROR enum value.""" - assert ErrorHandleMode.CONTINUE_ON_ERROR == "continue-on-error" - assert ErrorHandleMode.CONTINUE_ON_ERROR.value == "continue-on-error" - - def test_remove_abnormal_output_value(self): - """Test REMOVE_ABNORMAL_OUTPUT enum value.""" - assert ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT == "remove-abnormal-output" - assert ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT.value == "remove-abnormal-output" - - def test_error_handle_mode_is_str_enum(self): - """Test ErrorHandleMode is a string enum.""" - assert isinstance(ErrorHandleMode.TERMINATED, str) - assert isinstance(ErrorHandleMode.CONTINUE_ON_ERROR, str) - assert isinstance(ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT, str) - - def test_error_handle_mode_comparison(self): - """Test ErrorHandleMode can be compared with strings.""" - assert ErrorHandleMode.TERMINATED == "terminated" - assert ErrorHandleMode.CONTINUE_ON_ERROR == "continue-on-error" - - def test_all_error_handle_modes(self): - """Test all ErrorHandleMode values are accessible.""" - modes = list(ErrorHandleMode) - - assert len(modes) == 3 - assert ErrorHandleMode.TERMINATED in modes - assert ErrorHandleMode.CONTINUE_ON_ERROR in modes - assert ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT in modes - - -class TestIterationNodeData: - """Test suite for IterationNodeData model.""" - - def test_iteration_node_data_basic(self): - """Test IterationNodeData with basic 
configuration.""" - data = IterationNodeData( - title="Test Iteration", - iterator_selector=["node1", "output"], - output_selector=["iteration", "result"], - ) - - assert data.title == "Test Iteration" - assert data.iterator_selector == ["node1", "output"] - assert data.output_selector == ["iteration", "result"] - - def test_iteration_node_data_default_values(self): - """Test IterationNodeData default values.""" - data = IterationNodeData( - title="Default Test", - iterator_selector=["start", "items"], - output_selector=["iter", "out"], - ) - - assert data.parent_loop_id is None - assert data.is_parallel is False - assert data.parallel_nums == 10 - assert data.error_handle_mode == ErrorHandleMode.TERMINATED - assert data.flatten_output is True - - def test_iteration_node_data_parallel_mode(self): - """Test IterationNodeData with parallel mode enabled.""" - data = IterationNodeData( - title="Parallel Iteration", - iterator_selector=["node", "list"], - output_selector=["iter", "output"], - is_parallel=True, - parallel_nums=5, - ) - - assert data.is_parallel is True - assert data.parallel_nums == 5 - - def test_iteration_node_data_custom_parallel_nums(self): - """Test IterationNodeData with custom parallel numbers.""" - data = IterationNodeData( - title="Custom Parallel", - iterator_selector=["a", "b"], - output_selector=["c", "d"], - parallel_nums=20, - ) - - assert data.parallel_nums == 20 - - def test_iteration_node_data_continue_on_error(self): - """Test IterationNodeData with continue on error mode.""" - data = IterationNodeData( - title="Continue Error", - iterator_selector=["x", "y"], - output_selector=["z", "w"], - error_handle_mode=ErrorHandleMode.CONTINUE_ON_ERROR, - ) - - assert data.error_handle_mode == ErrorHandleMode.CONTINUE_ON_ERROR - - def test_iteration_node_data_remove_abnormal_output(self): - """Test IterationNodeData with remove abnormal output mode.""" - data = IterationNodeData( - title="Remove Abnormal", - iterator_selector=["input", "array"], 
- output_selector=["output", "result"], - error_handle_mode=ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT, - ) - - assert data.error_handle_mode == ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT - - def test_iteration_node_data_flatten_output_disabled(self): - """Test IterationNodeData with flatten output disabled.""" - data = IterationNodeData( - title="No Flatten", - iterator_selector=["a"], - output_selector=["b"], - flatten_output=False, - ) - - assert data.flatten_output is False - - def test_iteration_node_data_with_parent_loop_id(self): - """Test IterationNodeData with parent loop ID.""" - data = IterationNodeData( - title="Nested Loop", - iterator_selector=["parent", "items"], - output_selector=["child", "output"], - parent_loop_id="parent_loop_123", - ) - - assert data.parent_loop_id == "parent_loop_123" - - def test_iteration_node_data_complex_selectors(self): - """Test IterationNodeData with complex selectors.""" - data = IterationNodeData( - title="Complex Selectors", - iterator_selector=["node1", "output", "data", "items"], - output_selector=["iteration", "result", "value"], - ) - - assert len(data.iterator_selector) == 4 - assert len(data.output_selector) == 3 - - def test_iteration_node_data_all_options(self): - """Test IterationNodeData with all options configured.""" - data = IterationNodeData( - title="Full Config", - iterator_selector=["start", "list"], - output_selector=["end", "result"], - parent_loop_id="outer_loop", - is_parallel=True, - parallel_nums=15, - error_handle_mode=ErrorHandleMode.CONTINUE_ON_ERROR, - flatten_output=False, - ) - - assert data.title == "Full Config" - assert data.parent_loop_id == "outer_loop" - assert data.is_parallel is True - assert data.parallel_nums == 15 - assert data.error_handle_mode == ErrorHandleMode.CONTINUE_ON_ERROR - assert data.flatten_output is False - - -class TestIterationStartNodeData: - """Test suite for IterationStartNodeData model.""" - - def test_iteration_start_node_data_basic(self): - """Test 
IterationStartNodeData basic creation.""" - data = IterationStartNodeData(title="Iteration Start") - - assert data.title == "Iteration Start" - - def test_iteration_start_node_data_with_description(self): - """Test IterationStartNodeData with description.""" - data = IterationStartNodeData( - title="Start Node", - desc="This is the start of iteration", - ) - - assert data.title == "Start Node" - assert data.desc == "This is the start of iteration" - - -class TestIterationState: - """Test suite for IterationState model.""" - - def test_iteration_state_default_values(self): - """Test IterationState default values.""" - state = IterationState() - - assert state.outputs == [] - assert state.current_output is None - - def test_iteration_state_with_outputs(self): - """Test IterationState with outputs.""" - state = IterationState(outputs=["result1", "result2", "result3"]) - - assert len(state.outputs) == 3 - assert state.outputs[0] == "result1" - assert state.outputs[2] == "result3" - - def test_iteration_state_with_current_output(self): - """Test IterationState with current output.""" - state = IterationState(current_output="current_value") - - assert state.current_output == "current_value" - - def test_iteration_state_get_last_output_with_outputs(self): - """Test get_last_output with outputs present.""" - state = IterationState(outputs=["first", "second", "last"]) - - result = state.get_last_output() - - assert result == "last" - - def test_iteration_state_get_last_output_empty(self): - """Test get_last_output with empty outputs.""" - state = IterationState(outputs=[]) - - result = state.get_last_output() - - assert result is None - - def test_iteration_state_get_last_output_single(self): - """Test get_last_output with single output.""" - state = IterationState(outputs=["only_one"]) - - result = state.get_last_output() - - assert result == "only_one" - - def test_iteration_state_get_current_output(self): - """Test get_current_output method.""" - state = 
IterationState(current_output={"key": "value"}) - - result = state.get_current_output() - - assert result == {"key": "value"} - - def test_iteration_state_get_current_output_none(self): - """Test get_current_output when None.""" - state = IterationState() - - result = state.get_current_output() - - assert result is None - - def test_iteration_state_with_complex_outputs(self): - """Test IterationState with complex output types.""" - state = IterationState( - outputs=[ - {"id": 1, "name": "first"}, - {"id": 2, "name": "second"}, - [1, 2, 3], - "string_output", - ] - ) - - assert len(state.outputs) == 4 - assert state.outputs[0] == {"id": 1, "name": "first"} - assert state.outputs[2] == [1, 2, 3] - - def test_iteration_state_with_none_outputs(self): - """Test IterationState with None values in outputs.""" - state = IterationState(outputs=["value1", None, "value3"]) - - assert len(state.outputs) == 3 - assert state.outputs[1] is None - - def test_iteration_state_get_last_output_with_none(self): - """Test get_last_output when last output is None.""" - state = IterationState(outputs=["first", None]) - - result = state.get_last_output() - - assert result is None - - def test_iteration_state_metadata_class(self): - """Test IterationState.MetaData class.""" - metadata = IterationState.MetaData(iterator_length=10) - - assert metadata.iterator_length == 10 - - def test_iteration_state_metadata_different_lengths(self): - """Test IterationState.MetaData with different lengths.""" - metadata1 = IterationState.MetaData(iterator_length=0) - metadata2 = IterationState.MetaData(iterator_length=100) - metadata3 = IterationState.MetaData(iterator_length=1000000) - - assert metadata1.iterator_length == 0 - assert metadata2.iterator_length == 100 - assert metadata3.iterator_length == 1000000 - - def test_iteration_state_outputs_modification(self): - """Test modifying IterationState outputs.""" - state = IterationState(outputs=[]) - - state.outputs.append("new_output") - 
state.outputs.append("another_output") - - assert len(state.outputs) == 2 - assert state.get_last_output() == "another_output" - - def test_iteration_state_current_output_update(self): - """Test updating current_output.""" - state = IterationState() - - state.current_output = "first_value" - assert state.get_current_output() == "first_value" - - state.current_output = "updated_value" - assert state.get_current_output() == "updated_value" - - def test_iteration_state_with_numeric_outputs(self): - """Test IterationState with numeric outputs.""" - state = IterationState(outputs=[1, 2, 3, 4, 5]) - - assert state.get_last_output() == 5 - assert len(state.outputs) == 5 - - def test_iteration_state_with_boolean_outputs(self): - """Test IterationState with boolean outputs.""" - state = IterationState(outputs=[True, False, True]) - - assert state.get_last_output() is True - assert state.outputs[1] is False diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/iteration_node_spec.py b/api/tests/unit_tests/core/workflow/nodes/iteration/iteration_node_spec.py deleted file mode 100644 index fdf5f4d1f8..0000000000 --- a/api/tests/unit_tests/core/workflow/nodes/iteration/iteration_node_spec.py +++ /dev/null @@ -1,438 +0,0 @@ -from dify_graph.entities.graph_config import NodeConfigDictAdapter -from dify_graph.enums import BuiltinNodeTypes -from dify_graph.nodes.iteration.entities import ErrorHandleMode, IterationNodeData -from dify_graph.nodes.iteration.exc import ( - InvalidIteratorValueError, - IterationGraphNotFoundError, - IterationIndexNotFoundError, - IterationNodeError, - IteratorVariableNotFoundError, - StartNodeIdNotFoundError, -) -from dify_graph.nodes.iteration.iteration_node import IterationNode - - -class TestIterationNodeExceptions: - """Test suite for iteration node exceptions.""" - - def test_iteration_node_error_is_value_error(self): - """Test IterationNodeError inherits from ValueError.""" - error = IterationNodeError("test error") - - assert 
isinstance(error, ValueError) - assert str(error) == "test error" - - def test_iterator_variable_not_found_error(self): - """Test IteratorVariableNotFoundError.""" - error = IteratorVariableNotFoundError("Iterator variable not found") - - assert isinstance(error, IterationNodeError) - assert isinstance(error, ValueError) - assert "Iterator variable not found" in str(error) - - def test_invalid_iterator_value_error(self): - """Test InvalidIteratorValueError.""" - error = InvalidIteratorValueError("Invalid iterator value") - - assert isinstance(error, IterationNodeError) - assert "Invalid iterator value" in str(error) - - def test_start_node_id_not_found_error(self): - """Test StartNodeIdNotFoundError.""" - error = StartNodeIdNotFoundError("Start node ID not found") - - assert isinstance(error, IterationNodeError) - assert "Start node ID not found" in str(error) - - def test_iteration_graph_not_found_error(self): - """Test IterationGraphNotFoundError.""" - error = IterationGraphNotFoundError("Iteration graph not found") - - assert isinstance(error, IterationNodeError) - assert "Iteration graph not found" in str(error) - - def test_iteration_index_not_found_error(self): - """Test IterationIndexNotFoundError.""" - error = IterationIndexNotFoundError("Iteration index not found") - - assert isinstance(error, IterationNodeError) - assert "Iteration index not found" in str(error) - - def test_exception_with_empty_message(self): - """Test exception with empty message.""" - error = IterationNodeError("") - - assert str(error) == "" - - def test_exception_with_detailed_message(self): - """Test exception with detailed message.""" - error = IteratorVariableNotFoundError("Variable 'items' not found in node 'start_node'") - - assert "items" in str(error) - assert "start_node" in str(error) - - def test_all_exceptions_inherit_from_base(self): - """Test all exceptions inherit from IterationNodeError.""" - exceptions = [ - IteratorVariableNotFoundError("test"), - 
InvalidIteratorValueError("test"), - StartNodeIdNotFoundError("test"), - IterationGraphNotFoundError("test"), - IterationIndexNotFoundError("test"), - ] - - for exc in exceptions: - assert isinstance(exc, IterationNodeError) - assert isinstance(exc, ValueError) - - -class TestIterationNodeClassAttributes: - """Test suite for IterationNode class attributes.""" - - def test_node_type(self): - """Test IterationNode node_type attribute.""" - assert IterationNode.node_type == BuiltinNodeTypes.ITERATION - - def test_version(self): - """Test IterationNode version method.""" - version = IterationNode.version() - - assert version == "1" - - -class TestIterationNodeDefaultConfig: - """Test suite for IterationNode get_default_config.""" - - def test_get_default_config_returns_dict(self): - """Test get_default_config returns a dictionary.""" - config = IterationNode.get_default_config() - - assert isinstance(config, dict) - - def test_get_default_config_type(self): - """Test get_default_config includes type.""" - config = IterationNode.get_default_config() - - assert config.get("type") == "iteration" - - def test_get_default_config_has_config_section(self): - """Test get_default_config has config section.""" - config = IterationNode.get_default_config() - - assert "config" in config - assert isinstance(config["config"], dict) - - def test_get_default_config_is_parallel_default(self): - """Test get_default_config is_parallel default value.""" - config = IterationNode.get_default_config() - - assert config["config"]["is_parallel"] is False - - def test_get_default_config_parallel_nums_default(self): - """Test get_default_config parallel_nums default value.""" - config = IterationNode.get_default_config() - - assert config["config"]["parallel_nums"] == 10 - - def test_get_default_config_error_handle_mode_default(self): - """Test get_default_config error_handle_mode default value.""" - config = IterationNode.get_default_config() - - assert config["config"]["error_handle_mode"] == 
ErrorHandleMode.TERMINATED - - def test_get_default_config_flatten_output_default(self): - """Test get_default_config flatten_output default value.""" - config = IterationNode.get_default_config() - - assert config["config"]["flatten_output"] is True - - def test_get_default_config_with_none_filters(self): - """Test get_default_config with None filters.""" - config = IterationNode.get_default_config(filters=None) - - assert config is not None - assert "type" in config - - def test_get_default_config_with_empty_filters(self): - """Test get_default_config with empty filters.""" - config = IterationNode.get_default_config(filters={}) - - assert config is not None - - -class TestIterationNodeInitialization: - """Test suite for IterationNode initialization.""" - - def test_init_node_data_basic(self): - """Test init_node_data with basic configuration.""" - node = IterationNode.__new__(IterationNode) - data = { - "title": "Test Iteration", - "iterator_selector": ["start", "items"], - "output_selector": ["iteration", "result"], - } - - node.init_node_data(data) - - assert node._node_data.title == "Test Iteration" - assert node._node_data.iterator_selector == ["start", "items"] - - def test_init_node_data_with_parallel(self): - """Test init_node_data with parallel configuration.""" - node = IterationNode.__new__(IterationNode) - data = { - "title": "Parallel Iteration", - "iterator_selector": ["node", "list"], - "output_selector": ["out", "result"], - "is_parallel": True, - "parallel_nums": 5, - } - - node.init_node_data(data) - - assert node._node_data.is_parallel is True - assert node._node_data.parallel_nums == 5 - - def test_init_node_data_with_error_handle_mode(self): - """Test init_node_data with error handle mode.""" - node = IterationNode.__new__(IterationNode) - data = { - "title": "Error Handle Test", - "iterator_selector": ["a", "b"], - "output_selector": ["c", "d"], - "error_handle_mode": "continue-on-error", - } - - node.init_node_data(data) - - assert 
node._node_data.error_handle_mode == ErrorHandleMode.CONTINUE_ON_ERROR - - def test_get_title(self): - """Test _get_title method.""" - node = IterationNode.__new__(IterationNode) - node._node_data = IterationNodeData( - title="My Iteration", - iterator_selector=["x"], - output_selector=["y"], - ) - - assert node._get_title() == "My Iteration" - - def test_get_description_none(self): - """Test _get_description returns None when not set.""" - node = IterationNode.__new__(IterationNode) - node._node_data = IterationNodeData( - title="Test", - iterator_selector=["a"], - output_selector=["b"], - ) - - assert node._get_description() is None - - def test_get_description_with_value(self): - """Test _get_description with value.""" - node = IterationNode.__new__(IterationNode) - node._node_data = IterationNodeData( - title="Test", - desc="This is a description", - iterator_selector=["a"], - output_selector=["b"], - ) - - assert node._get_description() == "This is a description" - - def test_node_data_property(self): - """Test node_data property returns node data.""" - node = IterationNode.__new__(IterationNode) - node._node_data = IterationNodeData( - title="Base Test", - iterator_selector=["x"], - output_selector=["y"], - ) - - result = node.node_data - - assert result == node._node_data - - -class TestIterationNodeDataValidation: - """Test suite for IterationNodeData validation scenarios.""" - - def test_valid_iteration_node_data(self): - """Test valid IterationNodeData creation.""" - data = IterationNodeData( - title="Valid Iteration", - iterator_selector=["start", "items"], - output_selector=["end", "result"], - ) - - assert data.title == "Valid Iteration" - - def test_iteration_node_data_with_all_error_modes(self): - """Test IterationNodeData with all error handle modes.""" - modes = [ - ErrorHandleMode.TERMINATED, - ErrorHandleMode.CONTINUE_ON_ERROR, - ErrorHandleMode.REMOVE_ABNORMAL_OUTPUT, - ] - - for mode in modes: - data = IterationNodeData( - title=f"Test {mode}", 
- iterator_selector=["a"], - output_selector=["b"], - error_handle_mode=mode, - ) - assert data.error_handle_mode == mode - - def test_iteration_node_data_parallel_configuration(self): - """Test IterationNodeData parallel configuration combinations.""" - configs = [ - (False, 10), - (True, 1), - (True, 5), - (True, 20), - (True, 100), - ] - - for is_parallel, parallel_nums in configs: - data = IterationNodeData( - title="Parallel Test", - iterator_selector=["x"], - output_selector=["y"], - is_parallel=is_parallel, - parallel_nums=parallel_nums, - ) - assert data.is_parallel == is_parallel - assert data.parallel_nums == parallel_nums - - def test_iteration_node_data_flatten_output_options(self): - """Test IterationNodeData flatten_output options.""" - data_flatten = IterationNodeData( - title="Flatten True", - iterator_selector=["a"], - output_selector=["b"], - flatten_output=True, - ) - - data_no_flatten = IterationNodeData( - title="Flatten False", - iterator_selector=["a"], - output_selector=["b"], - flatten_output=False, - ) - - assert data_flatten.flatten_output is True - assert data_no_flatten.flatten_output is False - - def test_iteration_node_data_complex_selectors(self): - """Test IterationNodeData with complex selectors.""" - data = IterationNodeData( - title="Complex", - iterator_selector=["node1", "output", "data", "items", "list"], - output_selector=["iteration", "result", "value", "final"], - ) - - assert len(data.iterator_selector) == 5 - assert len(data.output_selector) == 4 - - def test_iteration_node_data_single_element_selectors(self): - """Test IterationNodeData with single element selectors.""" - data = IterationNodeData( - title="Single", - iterator_selector=["items"], - output_selector=["result"], - ) - - assert len(data.iterator_selector) == 1 - assert len(data.output_selector) == 1 - - -class TestIterationNodeErrorStrategies: - """Test suite for IterationNode error strategies.""" - - def test_get_error_strategy_default(self): - """Test 
_get_error_strategy with default value.""" - node = IterationNode.__new__(IterationNode) - node._node_data = IterationNodeData( - title="Test", - iterator_selector=["a"], - output_selector=["b"], - ) - - result = node._get_error_strategy() - - assert result is None or result == node._node_data.error_strategy - - def test_get_retry_config(self): - """Test _get_retry_config method.""" - node = IterationNode.__new__(IterationNode) - node._node_data = IterationNodeData( - title="Test", - iterator_selector=["a"], - output_selector=["b"], - ) - - result = node._get_retry_config() - - assert result is not None - - def test_get_default_value_dict(self): - """Test _get_default_value_dict method.""" - node = IterationNode.__new__(IterationNode) - node._node_data = IterationNodeData( - title="Test", - iterator_selector=["a"], - output_selector=["b"], - ) - - result = node._get_default_value_dict() - - assert isinstance(result, dict) - - -def test_extract_variable_selector_to_variable_mapping_validates_child_node_configs(monkeypatch) -> None: - seen_configs: list[object] = [] - original_validate_python = NodeConfigDictAdapter.validate_python - - def record_validate_python(value: object): - seen_configs.append(value) - return original_validate_python(value) - - monkeypatch.setattr(NodeConfigDictAdapter, "validate_python", record_validate_python) - - child_node_config = { - "id": "answer-node", - "data": { - "type": "answer", - "title": "Answer", - "answer": "", - "iteration_id": "iteration-node", - }, - } - - IterationNode._extract_variable_selector_to_variable_mapping( - graph_config={ - "nodes": [ - { - "id": "iteration-node", - "data": { - "type": "iteration", - "title": "Iteration", - "iterator_selector": ["start", "items"], - "output_selector": ["iteration", "result"], - }, - }, - child_node_config, - ], - "edges": [], - }, - node_id="iteration-node", - node_data=IterationNodeData( - title="Iteration", - iterator_selector=["start", "items"], - output_selector=["iteration", 
"result"], - ), - ) - - assert seen_configs == [child_node_config] diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py index 2eb4feef5f..bbfe350f7e 100644 --- a/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py +++ b/api/tests/unit_tests/core/workflow/nodes/iteration/test_iteration_child_engine_errors.py @@ -1,18 +1,18 @@ -from collections.abc import Mapping, Sequence +from collections.abc import Mapping from typing import Any import pytest - -from dify_graph.entities import GraphInitParams -from dify_graph.nodes.iteration.exc import IterationGraphNotFoundError -from dify_graph.nodes.iteration.iteration_node import IterationNode -from dify_graph.runtime import ( +from graphon.entities import GraphInitParams +from graphon.nodes.iteration.exc import IterationGraphNotFoundError +from graphon.nodes.iteration.iteration_node import IterationNode +from graphon.runtime import ( ChildEngineBuilderNotConfiguredError, ChildGraphNotFoundError, GraphRuntimeState, VariablePool, ) -from dify_graph.system_variable import SystemVariable + +from core.workflow.system_variables import default_system_variables from tests.workflow_test_utils import build_test_graph_init_params @@ -22,17 +22,16 @@ class _MissingGraphBuilder: *, workflow_id: str, graph_init_params: GraphInitParams, - graph_runtime_state: GraphRuntimeState, - graph_config: Mapping[str, Any], + parent_graph_runtime_state: GraphRuntimeState, root_node_id: str, - layers: Sequence[object] = (), + variable_pool: VariablePool | None = None, ) -> object: raise ChildGraphNotFoundError(f"child graph root node '{root_node_id}' not found") def _build_runtime_state() -> GraphRuntimeState: return GraphRuntimeState( - variable_pool=VariablePool(system_variables=SystemVariable.default(), user_inputs={}), + 
variable_pool=VariablePool(system_variables=default_system_variables(), user_inputs={}), start_at=0.0, ) @@ -69,8 +68,6 @@ def test_graph_runtime_state_raises_specific_error_when_child_builder_is_missing runtime_state.create_child_engine( workflow_id="workflow", graph_init_params=graph_init_params, - graph_runtime_state=_build_runtime_state(), - graph_config={}, root_node_id="root", ) diff --git a/api/tests/unit_tests/core/workflow/nodes/iteration/test_parallel_iteration_duration.py b/api/tests/unit_tests/core/workflow/nodes/iteration/test_parallel_iteration_duration.py deleted file mode 100644 index 8660449032..0000000000 --- a/api/tests/unit_tests/core/workflow/nodes/iteration/test_parallel_iteration_duration.py +++ /dev/null @@ -1,63 +0,0 @@ -import time -from contextlib import nullcontext -from datetime import UTC, datetime - -import pytest - -from dify_graph.enums import BuiltinNodeTypes -from dify_graph.graph_events import NodeRunSucceededEvent -from dify_graph.model_runtime.entities.llm_entities import LLMUsage -from dify_graph.nodes.iteration.entities import ErrorHandleMode, IterationNodeData -from dify_graph.nodes.iteration.iteration_node import IterationNode - - -def test_parallel_iteration_duration_map_uses_worker_measured_time() -> None: - node = IterationNode.__new__(IterationNode) - node._node_data = IterationNodeData( - title="Parallel Iteration", - iterator_selector=["start", "items"], - output_selector=["iteration", "output"], - is_parallel=True, - parallel_nums=2, - error_handle_mode=ErrorHandleMode.TERMINATED, - ) - node._capture_execution_context = lambda: nullcontext() - node._sync_conversation_variables_from_snapshot = lambda snapshot: None - node._merge_usage = lambda current, new: new if current.total_tokens == 0 else current.plus(new) - - def fake_execute_single_iteration_parallel(*, index: int, item: object, execution_context: object): - return ( - 0.1 + (index * 0.1), - [ - NodeRunSucceededEvent( - id=f"exec-{index}", - 
node_id=f"llm-{index}", - node_type=BuiltinNodeTypes.LLM, - start_at=datetime.now(UTC).replace(tzinfo=None), - ), - ], - f"output-{item}", - {}, - LLMUsage.empty_usage(), - ) - - node._execute_single_iteration_parallel = fake_execute_single_iteration_parallel - - outputs: list[object] = [] - iter_run_map: dict[str, float] = {} - usage_accumulator = [LLMUsage.empty_usage()] - - generator = node._execute_parallel_iterations( - iterator_list_value=["a", "b"], - outputs=outputs, - iter_run_map=iter_run_map, - usage_accumulator=usage_accumulator, - ) - - for _ in generator: - # Simulate a slow consumer replaying buffered events. - time.sleep(0.02) - - assert outputs == ["output-a", "output-b"] - assert iter_run_map["0"] == pytest.approx(0.1) - assert iter_run_map["1"] == pytest.approx(0.2) diff --git a/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py b/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py index feb560bbc3..f8802138b5 100644 --- a/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/knowledge_index/test_knowledge_index_node.py @@ -3,6 +3,9 @@ import uuid from unittest.mock import Mock import pytest +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.variables.segments import StringSegment from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.rag.index_processor.constant.index_type import IndexTechniqueType @@ -15,10 +18,7 @@ from core.workflow.nodes.knowledge_index.protocols import ( PreviewItem, SummaryIndexServiceProtocol, ) -from dify_graph.enums import SystemVariableKey, WorkflowNodeExecutionStatus -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable -from dify_graph.variables.segments import StringSegment +from 
core.workflow.system_variables import SystemVariableKey, build_system_variables from tests.workflow_test_utils import build_test_graph_init_params @@ -41,7 +41,7 @@ def mock_graph_init_params(): def mock_graph_runtime_state(): """Create mock GraphRuntimeState.""" variable_pool = VariablePool( - system_variables=SystemVariable(user_id=str(uuid.uuid4()), files=[]), + system_variables=build_system_variables(user_id=str(uuid.uuid4()), files=[]), user_inputs={}, environment_variables=[], conversation_variables=[], diff --git a/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py b/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py index 99997db6b2..ab64be59ad 100644 --- a/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/knowledge_retrieval/test_knowledge_retrieval_node.py @@ -3,6 +3,10 @@ import uuid from unittest.mock import Mock import pytest +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.variables import StringSegment from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from core.workflow.nodes.knowledge_retrieval.entities import ( @@ -16,11 +20,7 @@ from core.workflow.nodes.knowledge_retrieval.entities import ( from core.workflow.nodes.knowledge_retrieval.exc import RateLimitExceededError from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode from core.workflow.nodes.knowledge_retrieval.retrieval import RAGRetrievalProtocol, Source -from dify_graph.enums import WorkflowNodeExecutionStatus -from dify_graph.model_runtime.entities.llm_entities import LLMUsage -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable -from 
dify_graph.variables import StringSegment +from core.workflow.system_variables import build_system_variables from tests.workflow_test_utils import build_test_graph_init_params @@ -43,7 +43,7 @@ def mock_graph_init_params(): def mock_graph_runtime_state(): """Create mock GraphRuntimeState.""" variable_pool = VariablePool( - system_variables=SystemVariable(user_id=str(uuid.uuid4()), files=[]), + system_variables=build_system_variables(user_id=str(uuid.uuid4()), files=[]), user_inputs={}, environment_variables=[], conversation_variables=[], @@ -157,7 +157,7 @@ class TestKnowledgeRetrievalNode: ): """Test _run with query variable in single mode.""" # Arrange - from dify_graph.nodes.llm.entities import ModelConfig + from graphon.nodes.llm.entities import ModelConfig query = "What is Python?" query_selector = ["start", "query"] @@ -441,7 +441,7 @@ class TestFetchDatasetRetriever: ): """Test _fetch_dataset_retriever in single mode.""" # Arrange - from dify_graph.nodes.llm.entities import ModelConfig + from graphon.nodes.llm.entities import ModelConfig query = "What is Python?" 
variables = {"query": query} diff --git a/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py b/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py index d71e0921c1..fdf1706765 100644 --- a/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py +++ b/api/tests/unit_tests/core/workflow/nodes/list_operator/node_spec.py @@ -1,13 +1,13 @@ from unittest.mock import MagicMock import pytest +from graphon.entities import GraphInitParams +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus +from graphon.nodes.list_operator.node import ListOperatorNode +from graphon.runtime import GraphRuntimeState +from graphon.variables import ArrayNumberSegment, ArrayStringSegment -from dify_graph.entities import GraphInitParams -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from dify_graph.nodes.list_operator.node import ListOperatorNode -from dify_graph.runtime import GraphRuntimeState -from dify_graph.variables import ArrayNumberSegment, ArrayStringSegment +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY class TestListOperatorNode: diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_file_saver.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_file_saver.py deleted file mode 100644 index b0f0fd428b..0000000000 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_file_saver.py +++ /dev/null @@ -1,196 +0,0 @@ -import uuid -from typing import NamedTuple -from unittest import mock -from unittest.mock import MagicMock - -import httpx -import pytest - -from core.helper import ssrf_proxy -from core.tools import signature -from core.tools.tool_file_manager import ToolFileManager -from dify_graph.file import FileTransferMethod, FileType, models -from dify_graph.nodes.llm.file_saver import ( - FileSaverImpl, - _extract_content_type_and_extension, - _get_extension, - 
_validate_extension_override, -) -from models import ToolFile - -_PNG_DATA = b"\x89PNG\r\n\x1a\n" - - -def _gen_id(): - return str(uuid.uuid4()) - - -class TestFileSaverImpl: - def test_save_binary_string(self, monkeypatch: pytest.MonkeyPatch): - user_id = _gen_id() - tenant_id = _gen_id() - file_type = FileType.IMAGE - mime_type = "image/png" - mock_signed_url = "https://example.com/image.png" - mock_tool_file = ToolFile( - user_id=user_id, - tenant_id=tenant_id, - conversation_id=None, - file_key="test-file-key", - mimetype=mime_type, - original_url=None, - name=f"{_gen_id()}.png", - size=len(_PNG_DATA), - ) - mock_tool_file.id = _gen_id() - mocked_tool_file_manager = mock.MagicMock(spec=ToolFileManager) - - mocked_tool_file_manager.create_file_by_raw.return_value = mock_tool_file - monkeypatch.setattr(FileSaverImpl, "_get_tool_file_manager", lambda _: mocked_tool_file_manager) - # Since `File.generate_url` used `ToolFileManager.sign_file` directly, we also need to patch it here. - mocked_sign_file = mock.MagicMock(spec=signature.sign_tool_file) - # Since `File.generate_url` used `signature.sign_tool_file` directly, we also need to patch it here. 
- monkeypatch.setattr(models, "sign_tool_file", mocked_sign_file) - mocked_sign_file.return_value = mock_signed_url - http_client = MagicMock() - - storage_file_manager = FileSaverImpl( - user_id=user_id, - tenant_id=tenant_id, - http_client=http_client, - ) - - file = storage_file_manager.save_binary_string(_PNG_DATA, mime_type, file_type) - assert file.tenant_id == tenant_id - assert file.type == file_type - assert file.transfer_method == FileTransferMethod.TOOL_FILE - assert file.extension == ".png" - assert file.mime_type == mime_type - assert file.size == len(_PNG_DATA) - assert file.related_id == mock_tool_file.id - - assert file.generate_url() == mock_signed_url - - mocked_tool_file_manager.create_file_by_raw.assert_called_once_with( - user_id=user_id, - tenant_id=tenant_id, - conversation_id=None, - file_binary=_PNG_DATA, - mimetype=mime_type, - ) - mocked_sign_file.assert_called_once_with(tool_file_id=mock_tool_file.id, extension=".png", for_external=True) - - def test_save_remote_url_request_failed(self, monkeypatch: pytest.MonkeyPatch): - _TEST_URL = "https://example.com/image.png" - mock_request = httpx.Request("GET", _TEST_URL) - mock_response = httpx.Response( - status_code=401, - request=mock_request, - ) - http_client = MagicMock() - http_client.get.return_value = mock_response - - file_saver = FileSaverImpl( - user_id=_gen_id(), - tenant_id=_gen_id(), - http_client=http_client, - ) - - with pytest.raises(httpx.HTTPStatusError) as exc: - file_saver.save_remote_url(_TEST_URL, FileType.IMAGE) - http_client.get.assert_called_once_with(_TEST_URL) - assert exc.value.response.status_code == 401 - - def test_save_remote_url_success(self, monkeypatch: pytest.MonkeyPatch): - _TEST_URL = "https://example.com/image.png" - mime_type = "image/png" - user_id = _gen_id() - tenant_id = _gen_id() - - mock_request = httpx.Request("GET", _TEST_URL) - mock_response = httpx.Response( - status_code=200, - content=b"test-data", - headers={"Content-Type": mime_type}, - 
request=mock_request, - ) - http_client = MagicMock() - http_client.get.return_value = mock_response - - file_saver = FileSaverImpl(user_id=user_id, tenant_id=tenant_id, http_client=http_client) - mock_tool_file = ToolFile( - user_id=user_id, - tenant_id=tenant_id, - conversation_id=None, - file_key="test-file-key", - mimetype=mime_type, - original_url=None, - name=f"{_gen_id()}.png", - size=len(_PNG_DATA), - ) - mock_tool_file.id = _gen_id() - mock_get = mock.MagicMock(spec=ssrf_proxy.get, return_value=mock_response) - monkeypatch.setattr(ssrf_proxy, "get", mock_get) - mock_save_binary_string = mock.MagicMock(spec=file_saver.save_binary_string, return_value=mock_tool_file) - monkeypatch.setattr(file_saver, "save_binary_string", mock_save_binary_string) - - file = file_saver.save_remote_url(_TEST_URL, FileType.IMAGE) - mock_save_binary_string.assert_called_once_with( - mock_response.content, - mime_type, - FileType.IMAGE, - extension_override=".png", - ) - assert file == mock_tool_file - - -def test_validate_extension_override(): - class TestCase(NamedTuple): - extension_override: str | None - expected: str | None - - cases = [TestCase(None, None), TestCase("", ""), ".png", ".png", ".tar.gz", ".tar.gz"] - - for valid_ext_override in [None, "", ".png", ".tar.gz"]: - assert valid_ext_override == _validate_extension_override(valid_ext_override) - - for invalid_ext_override in ["png", "tar.gz"]: - with pytest.raises(ValueError) as exc: - _validate_extension_override(invalid_ext_override) - - -class TestExtractContentTypeAndExtension: - def test_with_both_content_type_and_extension(self): - content_type, extension = _extract_content_type_and_extension("https://example.com/image.jpg", "image/png") - assert content_type == "image/png" - assert extension == ".png" - - def test_url_with_file_extension(self): - for content_type in [None, ""]: - content_type, extension = _extract_content_type_and_extension("https://example.com/image.png", content_type) - assert content_type 
== "image/png" - assert extension == ".png" - - def test_response_with_content_type(self): - content_type, extension = _extract_content_type_and_extension("https://example.com/image", "image/png") - assert content_type == "image/png" - assert extension == ".png" - - def test_no_content_type_and_no_extension(self): - for content_type in [None, ""]: - content_type, extension = _extract_content_type_and_extension("https://example.com/image", content_type) - assert content_type == "application/octet-stream" - assert extension == ".bin" - - -class TestGetExtension: - def test_with_extension_override(self): - mime_type = "image/png" - for override in [".jpg", ""]: - extension = _get_extension(mime_type, override) - assert extension == override - - def test_without_extension_override(self): - mime_type = "image/png" - extension = _get_extension(mime_type) - assert extension == ".png" diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py index acecbf4944..c784f805c0 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py @@ -1,18 +1,86 @@ from unittest import mock import pytest - -from core.model_manager import ModelInstance -from dify_graph.model_runtime.entities import ( +from graphon.file import File, FileTransferMethod, FileType +from graphon.model_runtime.entities import ( ImagePromptMessageContent, PromptMessageRole, TextPromptMessageContent, ) -from dify_graph.model_runtime.entities.message_entities import SystemPromptMessage -from dify_graph.nodes.llm import llm_utils -from dify_graph.nodes.llm.entities import LLMNodeChatModelMessage -from dify_graph.nodes.llm.exc import NoPromptFoundError -from dify_graph.runtime import VariablePool +from graphon.model_runtime.entities.message_entities import ( + AssistantPromptMessage, + SystemPromptMessage, + UserPromptMessage, +) +from 
graphon.model_runtime.entities.model_entities import ( + AIModelEntity, + FetchFrom, + ModelFeature, + ModelPropertyKey, + ModelType, + ParameterRule, + ParameterType, +) +from graphon.nodes.base.entities import VariableSelector +from graphon.nodes.llm import llm_utils +from graphon.nodes.llm.entities import LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate, MemoryConfig +from graphon.nodes.llm.exc import ( + InvalidVariableTypeError, + MemoryRolePrefixRequiredError, + NoPromptFoundError, + TemplateTypeNotSupportError, +) +from graphon.runtime import VariablePool +from graphon.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment + +from core.model_manager import ModelInstance + + +def _build_model_schema( + *, + features: list[ModelFeature] | None = None, + model_properties: dict[ModelPropertyKey, object] | None = None, + parameter_rules: list[ParameterRule] | None = None, +) -> AIModelEntity: + return AIModelEntity( + model="gpt-3.5-turbo", + label={"en_US": "GPT-3.5 Turbo"}, + model_type=ModelType.LLM, + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + features=features, + model_properties=model_properties or {}, + parameter_rules=parameter_rules or [], + ) + + +def _build_model_instance(*, model_schema: AIModelEntity | None = None) -> mock.MagicMock: + model_instance = mock.MagicMock(spec=ModelInstance) + model_instance.model_name = "gpt-3.5-turbo" + model_instance.parameters = {} + model_instance.get_model_schema.return_value = model_schema or _build_model_schema(features=[]) + model_instance.get_llm_num_tokens.return_value = 0 + return model_instance + + +def _build_image_file( + *, + file_id: str, + related_id: str, + remote_url: str, + extension: str = ".png", + mime_type: str = "image/png", +) -> File: + return File( + id=file_id, + type=FileType.IMAGE, + filename=f"{file_id}{extension}", + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url=remote_url, + related_id=related_id, + extension=extension, + mime_type=mime_type, + 
storage_key="", + ) @pytest.fixture @@ -37,15 +105,15 @@ def _fetch_prompt_messages_with_mocked_content(content): with ( mock.patch( - "dify_graph.nodes.llm.llm_utils.fetch_model_schema", + "graphon.nodes.llm.llm_utils.fetch_model_schema", return_value=mock.MagicMock(features=[]), ), mock.patch( - "dify_graph.nodes.llm.llm_utils.handle_list_messages", + "graphon.nodes.llm.llm_utils.handle_list_messages", return_value=[SystemPromptMessage(content=content)], ), mock.patch( - "dify_graph.nodes.llm.llm_utils.handle_memory_chat_mode", + "graphon.nodes.llm.llm_utils.handle_memory_chat_mode", return_value=[], ), ): @@ -270,3 +338,700 @@ def test_fetch_prompt_messages_keeps_list_content_when_multiple_supported_items_ ] ) ] + + +def test_fetch_model_schema_raises_when_model_schema_is_missing(): + model_instance = _build_model_instance() + model_instance.get_model_schema.return_value = None + + with pytest.raises(ValueError, match="Model schema not found for gpt-3.5-turbo"): + llm_utils.fetch_model_schema(model_instance=model_instance) + + +def test_fetch_files_supports_known_segments_and_rejects_invalid_types(): + file = _build_image_file(file_id="image", related_id="image-related", remote_url="https://example.com/image.png") + variable_pool = VariablePool.empty() + variable_pool.add(["input", "file"], file) + variable_pool.add(["input", "files"], ArrayFileSegment(value=[file])) + variable_pool.add(["input", "none"], NoneSegment()) + variable_pool.add(["input", "empty"], ArrayAnySegment(value=[])) + variable_pool.add(["input", "invalid"], {"a": 1}) + + assert llm_utils.fetch_files(variable_pool, ["input", "file"]) == [file] + assert llm_utils.fetch_files(variable_pool, ["input", "files"]) == [file] + assert llm_utils.fetch_files(variable_pool, ["input", "none"]) == [] + assert llm_utils.fetch_files(variable_pool, ["input", "empty"]) == [] + + with pytest.raises(InvalidVariableTypeError, match="Invalid variable type"): + llm_utils.fetch_files(variable_pool, ["input", 
"invalid"]) + + +def test_fetch_files_returns_empty_for_missing_variable(): + assert llm_utils.fetch_files(VariablePool.empty(), ["input", "missing"]) == [] + + +def test_convert_history_messages_to_text_skips_system_messages_and_formats_images(): + history_text = llm_utils.convert_history_messages_to_text( + history_messages=[ + SystemPromptMessage(content="skip"), + UserPromptMessage( + content=[ + TextPromptMessageContent(data="Question"), + ImagePromptMessageContent( + format="png", + url="https://example.com/image.png", + mime_type="image/png", + ), + ] + ), + AssistantPromptMessage(content="Answer"), + ], + human_prefix="Human", + ai_prefix="Assistant", + ) + + assert history_text == "Human: Question\n[image]\nAssistant: Answer" + + +def test_fetch_memory_text_uses_prompt_memory_interface(): + memory = mock.MagicMock() + memory.get_history_prompt_messages.return_value = [UserPromptMessage(content="Question")] + + memory_text = llm_utils.fetch_memory_text( + memory=memory, + max_token_limit=321, + message_limit=2, + human_prefix="Human", + ai_prefix="Assistant", + ) + + assert memory_text == "Human: Question" + memory.get_history_prompt_messages.assert_called_once_with(max_token_limit=321, message_limit=2) + + +def test_handle_list_messages_renders_jinja2_messages(): + variable_pool = VariablePool.empty() + variable_pool.add(["input", "name"], "Dify") + renderer = mock.MagicMock() + renderer.render_template.return_value = "Hello Dify" + + prompt_messages = llm_utils.handle_list_messages( + messages=[ + LLMNodeChatModelMessage( + text="ignored", + jinja2_text="Hello {{ name }}", + role=PromptMessageRole.SYSTEM, + edition_type="jinja2", + ) + ], + context="", + jinja2_variables=[VariableSelector(variable="name", value_selector=["input", "name"])], + variable_pool=variable_pool, + vision_detail_config=ImagePromptMessageContent.DETAIL.HIGH, + template_renderer=renderer, + ) + + assert prompt_messages == 
[SystemPromptMessage(content=[TextPromptMessageContent(data="Hello Dify")])] + renderer.render_template.assert_called_once_with("Hello {{ name }}", {"name": "Dify"}) + + +def test_handle_list_messages_splits_text_and_file_content(): + variable_pool = VariablePool.empty() + image_file = _build_image_file( + file_id="image-file", + related_id="image-related", + remote_url="https://example.com/file.png", + ) + variable_pool.add(["input", "image"], image_file) + + with mock.patch( + "graphon.nodes.llm.llm_utils.file_manager.to_prompt_message_content", + return_value=ImagePromptMessageContent( + format="png", + url="https://example.com/file.png", + mime_type="image/png", + detail=ImagePromptMessageContent.DETAIL.HIGH, + ), + ) as mock_to_prompt: + prompt_messages = llm_utils.handle_list_messages( + messages=[ + LLMNodeChatModelMessage( + text="Analyze {{#input.image#}}", + role=PromptMessageRole.USER, + edition_type="basic", + ) + ], + context="", + jinja2_variables=[], + variable_pool=variable_pool, + vision_detail_config=ImagePromptMessageContent.DETAIL.HIGH, + ) + + assert prompt_messages == [ + UserPromptMessage(content=[TextPromptMessageContent(data="Analyze ")]), + UserPromptMessage( + content=[ + ImagePromptMessageContent( + format="png", + url="https://example.com/file.png", + mime_type="image/png", + detail=ImagePromptMessageContent.DETAIL.HIGH, + ) + ] + ), + ] + mock_to_prompt.assert_called_once() + + +def test_handle_list_messages_supports_array_file_segments(): + variable_pool = VariablePool.empty() + first_file = _build_image_file(file_id="first", related_id="first-related", remote_url="https://example.com/1.png") + second_file = _build_image_file( + file_id="second", + related_id="second-related", + remote_url="https://example.com/2.png", + ) + variable_pool.add(["input", "images"], ArrayFileSegment(value=[first_file, second_file])) + + first_prompt = ImagePromptMessageContent( + format="png", + url="https://example.com/1.png", + mime_type="image/png", + 
detail=ImagePromptMessageContent.DETAIL.HIGH, + ) + second_prompt = ImagePromptMessageContent( + format="png", + url="https://example.com/2.png", + mime_type="image/png", + detail=ImagePromptMessageContent.DETAIL.HIGH, + ) + + with mock.patch( + "graphon.nodes.llm.llm_utils.file_manager.to_prompt_message_content", + side_effect=[first_prompt, second_prompt], + ): + prompt_messages = llm_utils.handle_list_messages( + messages=[ + LLMNodeChatModelMessage( + text="{{#input.images#}}", + role=PromptMessageRole.USER, + edition_type="basic", + ) + ], + context="", + jinja2_variables=[], + variable_pool=variable_pool, + vision_detail_config=ImagePromptMessageContent.DETAIL.HIGH, + ) + + assert prompt_messages == [UserPromptMessage(content=[first_prompt, second_prompt])] + + +def test_render_jinja2_message_handles_empty_template_success_and_missing_renderer(): + variable_pool = VariablePool.empty() + variable_pool.add(["input", "name"], "Dify") + variables = [VariableSelector(variable="name", value_selector=["input", "name"])] + + assert ( + llm_utils.render_jinja2_message( + template="", + jinja2_variables=variables, + variable_pool=variable_pool, + template_renderer=None, + ) + == "" + ) + + with pytest.raises(ValueError, match="template_renderer is required"): + llm_utils.render_jinja2_message( + template="Hello {{ name }}", + jinja2_variables=variables, + variable_pool=variable_pool, + template_renderer=None, + ) + + renderer = mock.MagicMock() + renderer.render_template.return_value = "Hello Dify" + assert ( + llm_utils.render_jinja2_message( + template="Hello {{ name }}", + jinja2_variables=variables, + variable_pool=variable_pool, + template_renderer=renderer, + ) + == "Hello Dify" + ) + + +def test_handle_completion_template_supports_basic_and_jinja2_templates(): + variable_pool = VariablePool.empty() + variable_pool.add(["input", "name"], "Dify") + renderer = mock.MagicMock() + renderer.render_template.return_value = "Hello Dify" + + basic_messages = 
llm_utils.handle_completion_template( + template=LLMNodeCompletionModelPromptTemplate( + text="Summarize {{#context#}}", + edition_type="basic", + ), + context="the docs", + jinja2_variables=[], + variable_pool=variable_pool, + ) + jinja_messages = llm_utils.handle_completion_template( + template=LLMNodeCompletionModelPromptTemplate( + text="ignored", + jinja2_text="Hello {{ name }}", + edition_type="jinja2", + ), + context="", + jinja2_variables=[VariableSelector(variable="name", value_selector=["input", "name"])], + variable_pool=variable_pool, + template_renderer=renderer, + ) + + assert basic_messages == [ + UserPromptMessage(content=[TextPromptMessageContent(data="Summarize the docs")]), + ] + assert jinja_messages == [ + UserPromptMessage(content=[TextPromptMessageContent(data="Hello Dify")]), + ] + + +def test_combine_message_content_with_role_handles_all_supported_roles(): + contents = [TextPromptMessageContent(data="hello")] + + assert llm_utils.combine_message_content_with_role(contents=contents, role=PromptMessageRole.USER) == ( + UserPromptMessage(content=contents) + ) + assert llm_utils.combine_message_content_with_role(contents=contents, role=PromptMessageRole.ASSISTANT) == ( + AssistantPromptMessage(content=contents) + ) + assert llm_utils.combine_message_content_with_role(contents=contents, role=PromptMessageRole.SYSTEM) == ( + SystemPromptMessage(content=contents) + ) + + with pytest.raises(NotImplementedError, match="Role custom is not supported"): + llm_utils.combine_message_content_with_role(contents=contents, role="custom") # type: ignore[arg-type] + + +def test_calculate_rest_token_uses_context_size_and_template_alias(): + model_instance = _build_model_instance( + model_schema=_build_model_schema( + model_properties={ModelPropertyKey.CONTEXT_SIZE: 4096}, + parameter_rules=[ + ParameterRule( + name="output_limit", + use_template="max_tokens", + label={"en_US": "Output Limit"}, + type=ParameterType.INT, + ) + ], + ) + ) + 
model_instance.parameters = {"max_tokens": 512} + model_instance.get_llm_num_tokens.return_value = 256 + + assert ( + llm_utils.calculate_rest_token( + prompt_messages=[UserPromptMessage(content="hello")], + model_instance=model_instance, + ) + == 3328 + ) + + +def test_handle_memory_chat_mode_returns_empty_without_memory_and_uses_window_when_present(): + model_instance = _build_model_instance() + memory = mock.MagicMock() + memory.get_history_prompt_messages.return_value = [UserPromptMessage(content="Question")] + + assert ( + llm_utils.handle_memory_chat_mode( + memory=None, + memory_config=None, + model_instance=model_instance, + ) + == [] + ) + + with mock.patch("graphon.nodes.llm.llm_utils.calculate_rest_token", return_value=123) as mock_rest: + messages = llm_utils.handle_memory_chat_mode( + memory=memory, + memory_config=MemoryConfig(window=MemoryConfig.WindowConfig(enabled=True, size=2)), + model_instance=model_instance, + ) + + assert messages == [UserPromptMessage(content="Question")] + mock_rest.assert_called_once() + memory.get_history_prompt_messages.assert_called_once_with(max_token_limit=123, message_limit=2) + + +def test_handle_memory_completion_mode_validates_role_prefix_and_formats_history(): + model_instance = _build_model_instance() + memory = mock.MagicMock() + memory.get_history_prompt_messages.return_value = [ + UserPromptMessage(content="Question"), + AssistantPromptMessage(content="Answer"), + ] + + assert ( + llm_utils.handle_memory_completion_mode( + memory=None, + memory_config=None, + model_instance=model_instance, + ) + == "" + ) + + with ( + mock.patch("graphon.nodes.llm.llm_utils.calculate_rest_token", return_value=456), + pytest.raises(MemoryRolePrefixRequiredError, match="Memory role prefix is required"), + ): + llm_utils.handle_memory_completion_mode( + memory=memory, + memory_config=MemoryConfig(window=MemoryConfig.WindowConfig(enabled=True, size=2)), + model_instance=model_instance, + ) + + with 
mock.patch("graphon.nodes.llm.llm_utils.calculate_rest_token", return_value=456): + history_text = llm_utils.handle_memory_completion_mode( + memory=memory, + memory_config=MemoryConfig( + role_prefix=MemoryConfig.RolePrefix(user="Human", assistant="Assistant"), + window=MemoryConfig.WindowConfig(enabled=False), + ), + model_instance=model_instance, + ) + + assert history_text == "Human: Question\nAssistant: Answer" + memory.get_history_prompt_messages.assert_called_with(max_token_limit=456, message_limit=None) + + +def test_append_file_prompts_merges_with_existing_user_content_or_appends_new_message(): + file = _build_image_file(file_id="image", related_id="image-related", remote_url="https://example.com/image.png") + file_prompt = ImagePromptMessageContent( + format="png", + url="https://example.com/image.png", + mime_type="image/png", + detail=ImagePromptMessageContent.DETAIL.HIGH, + ) + prompt_messages = [UserPromptMessage(content=[TextPromptMessageContent(data="Question")])] + + with mock.patch( + "graphon.nodes.llm.llm_utils.file_manager.to_prompt_message_content", + return_value=file_prompt, + ): + llm_utils._append_file_prompts( + prompt_messages=prompt_messages, + files=[file], + vision_enabled=True, + vision_detail=ImagePromptMessageContent.DETAIL.HIGH, + ) + + assert prompt_messages == [ + UserPromptMessage(content=[file_prompt, TextPromptMessageContent(data="Question")]), + ] + + prompt_messages = [SystemPromptMessage(content="System prompt")] + with mock.patch( + "graphon.nodes.llm.llm_utils.file_manager.to_prompt_message_content", + return_value=file_prompt, + ): + llm_utils._append_file_prompts( + prompt_messages=prompt_messages, + files=[file], + vision_enabled=True, + vision_detail=ImagePromptMessageContent.DETAIL.HIGH, + ) + + assert prompt_messages[-1] == UserPromptMessage(content=[file_prompt]) + + +def test_fetch_prompt_messages_chat_mode_includes_query_memory_and_supported_files(): + model_instance = 
_build_model_instance(model_schema=_build_model_schema(features=[ModelFeature.VISION])) + memory = mock.MagicMock() + memory.get_history_prompt_messages.return_value = [AssistantPromptMessage(content="history")] + sys_file = _build_image_file(file_id="sys", related_id="sys-related", remote_url="https://example.com/sys.png") + context_file = _build_image_file( + file_id="context", + related_id="context-related", + remote_url="https://example.com/context.png", + ) + file_prompts = [ + ImagePromptMessageContent( + format="png", + url="https://example.com/sys.png", + mime_type="image/png", + detail=ImagePromptMessageContent.DETAIL.HIGH, + ), + ImagePromptMessageContent( + format="png", + url="https://example.com/context.png", + mime_type="image/png", + detail=ImagePromptMessageContent.DETAIL.HIGH, + ), + ] + + with mock.patch( + "graphon.nodes.llm.llm_utils.file_manager.to_prompt_message_content", + side_effect=file_prompts, + ): + prompt_messages, stop = llm_utils.fetch_prompt_messages( + sys_query="current question", + sys_files=[sys_file], + context="", + memory=memory, + model_instance=model_instance, + prompt_template=[ + LLMNodeChatModelMessage( + text="Before query", + role=PromptMessageRole.USER, + edition_type="basic", + ) + ], + stop=("STOP",), + memory_config=MemoryConfig(window=MemoryConfig.WindowConfig(enabled=False)), + vision_enabled=True, + vision_detail=ImagePromptMessageContent.DETAIL.HIGH, + variable_pool=VariablePool.empty(), + jinja2_variables=[], + context_files=[context_file], + ) + + assert stop == ("STOP",) + assert prompt_messages[0] == UserPromptMessage(content="Before query") + assert prompt_messages[1] == AssistantPromptMessage(content="history") + assert prompt_messages[2] == UserPromptMessage( + content=[ + file_prompts[1], + file_prompts[0], + TextPromptMessageContent(data="current question"), + ] + ) + + +def test_fetch_prompt_messages_completion_mode_updates_list_content_with_histories_and_query(): + model_instance = 
_build_model_instance(model_schema=_build_model_schema(features=[])) + memory = mock.MagicMock() + memory.get_history_prompt_messages.return_value = [ + UserPromptMessage(content="previous question"), + AssistantPromptMessage(content="previous answer"), + ] + + prompt_messages, stop = llm_utils.fetch_prompt_messages( + sys_query="latest question", + sys_files=[], + context="", + memory=memory, + model_instance=model_instance, + prompt_template=LLMNodeCompletionModelPromptTemplate( + text="Prompt header\n#histories#", + edition_type="basic", + ), + stop=("HALT",), + memory_config=MemoryConfig( + role_prefix=MemoryConfig.RolePrefix(user="Human", assistant="Assistant"), + window=MemoryConfig.WindowConfig(enabled=True, size=2), + ), + vision_enabled=False, + vision_detail=ImagePromptMessageContent.DETAIL.HIGH, + variable_pool=VariablePool.empty(), + jinja2_variables=[], + ) + + assert stop == ("HALT",) + assert prompt_messages == [ + UserPromptMessage( + content="latest question\nPrompt header\nHuman: previous question\nAssistant: previous answer" + ) + ] + + memory.get_history_prompt_messages.return_value = [ + UserPromptMessage(content="another question"), + AssistantPromptMessage(content="another answer"), + ] + + prompt_messages, _ = llm_utils.fetch_prompt_messages( + sys_query="latest question", + sys_files=[], + context="", + memory=memory, + model_instance=model_instance, + prompt_template=LLMNodeCompletionModelPromptTemplate( + text="Prompt header", + edition_type="basic", + ), + stop=None, + memory_config=MemoryConfig( + role_prefix=MemoryConfig.RolePrefix(user="Human", assistant="Assistant"), + window=MemoryConfig.WindowConfig(enabled=True, size=2), + ), + vision_enabled=False, + vision_detail=ImagePromptMessageContent.DETAIL.HIGH, + variable_pool=VariablePool.empty(), + jinja2_variables=[], + ) + + assert prompt_messages == [ + UserPromptMessage(content="latest question\nHuman: another question\nAssistant: another answer\nPrompt header") + ] + + +def 
test_fetch_prompt_messages_filters_content_unsupported_by_model_features(): + model_instance = _build_model_instance(model_schema=_build_model_schema(features=[ModelFeature.DOCUMENT])) + prompt_template = [ + LLMNodeChatModelMessage( + text="You are a classifier.", + role=PromptMessageRole.SYSTEM, + edition_type="basic", + ) + ] + + with ( + mock.patch( + "graphon.nodes.llm.llm_utils.handle_list_messages", + return_value=[ + SystemPromptMessage( + content=[ + TextPromptMessageContent(data="You are a classifier."), + ImagePromptMessageContent( + format="png", + url="https://example.com/image.png", + mime_type="image/png", + ), + ] + ) + ], + ), + mock.patch("graphon.nodes.llm.llm_utils.handle_memory_chat_mode", return_value=[]), + ): + prompt_messages, stop = llm_utils.fetch_prompt_messages( + sys_query=None, + sys_files=[], + context="", + memory=None, + model_instance=model_instance, + prompt_template=prompt_template, + stop=("END",), + memory_config=None, + vision_enabled=False, + vision_detail=ImagePromptMessageContent.DETAIL.HIGH, + variable_pool=VariablePool.empty(), + jinja2_variables=[], + ) + + assert stop == ("END",) + assert prompt_messages == [SystemPromptMessage(content="You are a classifier.")] + + +def test_fetch_prompt_messages_completion_mode_supports_string_content_and_invalid_template_type(): + model_instance = _build_model_instance(model_schema=_build_model_schema(features=[])) + + with ( + mock.patch( + "graphon.nodes.llm.llm_utils.handle_completion_template", + return_value=[UserPromptMessage(content="Prefix #histories# and #sys.query#")], + ), + mock.patch( + "graphon.nodes.llm.llm_utils.handle_memory_completion_mode", + return_value="history text", + ), + ): + prompt_messages, stop = llm_utils.fetch_prompt_messages( + sys_query="latest question", + sys_files=[], + context="", + memory=None, + model_instance=model_instance, + prompt_template=LLMNodeCompletionModelPromptTemplate( + text="ignored", + edition_type="basic", + ), + stop=("HALT",), 
+ memory_config=None, + vision_enabled=False, + vision_detail=ImagePromptMessageContent.DETAIL.HIGH, + variable_pool=VariablePool.empty(), + jinja2_variables=[], + ) + + assert stop == ("HALT",) + assert prompt_messages == [UserPromptMessage(content="Prefix history text and latest question")] + + with pytest.raises(TemplateTypeNotSupportError): + llm_utils.fetch_prompt_messages( + sys_query=None, + sys_files=[], + context="", + memory=None, + model_instance=model_instance, + prompt_template=object(), # type: ignore[arg-type] + stop=None, + memory_config=None, + vision_enabled=False, + vision_detail=ImagePromptMessageContent.DETAIL.HIGH, + variable_pool=VariablePool.empty(), + jinja2_variables=[], + ) + + invalid_prompt = mock.MagicMock() + invalid_prompt.content = object() + with ( + mock.patch( + "graphon.nodes.llm.llm_utils.handle_completion_template", + return_value=[invalid_prompt], + ), + mock.patch( + "graphon.nodes.llm.llm_utils.handle_memory_completion_mode", + return_value="history text", + ), + pytest.raises(ValueError, match="Invalid prompt content type"), + ): + llm_utils.fetch_prompt_messages( + sys_query="latest question", + sys_files=[], + context="", + memory=None, + model_instance=model_instance, + prompt_template=LLMNodeCompletionModelPromptTemplate( + text="ignored", + edition_type="basic", + ), + stop=None, + memory_config=None, + vision_enabled=False, + vision_detail=ImagePromptMessageContent.DETAIL.HIGH, + variable_pool=VariablePool.empty(), + jinja2_variables=[], + ) + + with ( + mock.patch( + "graphon.nodes.llm.llm_utils.handle_completion_template", + return_value=[UserPromptMessage(content="Prefix only")], + ), + mock.patch( + "graphon.nodes.llm.llm_utils.handle_memory_completion_mode", + return_value="history text", + ), + ): + prompt_messages, _ = llm_utils.fetch_prompt_messages( + sys_query=None, + sys_files=[], + context="", + memory=None, + model_instance=model_instance, + prompt_template=LLMNodeCompletionModelPromptTemplate( + 
text="ignored", + edition_type="basic", + ), + stop=None, + memory_config=None, + vision_enabled=False, + vision_detail=ImagePromptMessageContent.DETAIL.HIGH, + variable_pool=VariablePool.empty(), + jinja2_variables=[], + ) + + assert prompt_messages == [UserPromptMessage(content="history text\nPrefix only")] diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py index fc96088af1..a215e9d350 100644 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/llm/test_node.py @@ -4,41 +4,81 @@ from collections.abc import Sequence from unittest import mock import pytest - -from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity, UserFrom -from core.app.llm.model_access import DifyCredentialsProvider, DifyModelFactory, fetch_model_config -from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle -from core.entities.provider_entities import CustomConfiguration, SystemConfiguration -from core.model_manager import ModelInstance -from core.prompt.entities.advanced_prompt_entities import MemoryConfig -from dify_graph.entities import GraphInitParams -from dify_graph.file import File, FileTransferMethod, FileType -from dify_graph.model_runtime.entities.common_entities import I18nObject -from dify_graph.model_runtime.entities.message_entities import ( +from graphon.entities import GraphInitParams +from graphon.file import File, FileTransferMethod, FileType +from graphon.model_runtime.entities.common_entities import I18nObject +from graphon.model_runtime.entities.llm_entities import ( + LLMResultChunk, + LLMResultChunkDelta, + LLMResultChunkWithStructuredOutput, + LLMResultWithStructuredOutput, + LLMUsage, +) +from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, ImagePromptMessageContent, PromptMessage, PromptMessageRole, + SystemPromptMessage, 
TextPromptMessageContent, UserPromptMessage, ) -from dify_graph.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelType -from dify_graph.model_runtime.model_providers.model_provider_factory import ModelProviderFactory -from dify_graph.nodes.llm import llm_utils -from dify_graph.nodes.llm.entities import ( +from graphon.model_runtime.entities.model_entities import ( + AIModelEntity, + FetchFrom, + ModelFeature, + ModelPropertyKey, + ModelType, + ParameterRule, + ParameterType, +) +from graphon.model_runtime.model_providers.model_provider_factory import ModelProviderFactory +from graphon.node_events import ModelInvokeCompletedEvent, RunRetrieverResourceEvent, StreamChunkEvent +from graphon.nodes.base.entities import VariableSelector +from graphon.nodes.llm import llm_utils +from graphon.nodes.llm.entities import ( ContextConfig, LLMNodeChatModelMessage, + LLMNodeCompletionModelPromptTemplate, LLMNodeData, ModelConfig, + PromptConfig, VisionConfig, VisionConfigOptions, ) -from dify_graph.nodes.llm.file_saver import LLMFileSaver -from dify_graph.nodes.llm.node import LLMNode -from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable -from dify_graph.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment +from graphon.nodes.llm.exc import ( + InvalidContextStructureError, + LLMNodeError, + NoPromptFoundError, + VariableNotFoundError, +) +from graphon.nodes.llm.file_saver import LLMFileSaver +from graphon.nodes.llm.node import ( + LLMNode, + _calculate_rest_token, + _handle_completion_template, + _handle_memory_chat_mode, + _handle_memory_completion_mode, + _render_jinja2_message, +) +from graphon.nodes.llm.protocols import CredentialsProvider, ModelFactory +from graphon.nodes.llm.runtime_protocols import PromptMessageSerializerProtocol +from graphon.runtime import GraphRuntimeState, 
VariablePool +from graphon.template_rendering import TemplateRenderError +from graphon.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment + +from core.app.entities.app_invoke_entities import DifyRunContext, InvokeFrom, ModelConfigWithCredentialsEntity, UserFrom +from core.app.llm.model_access import ( + DifyCredentialsProvider, + DifyModelFactory, + build_dify_model_access, + fetch_model_config, +) +from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle +from core.entities.provider_entities import CustomConfiguration, SystemConfiguration +from core.plugin.impl.model_runtime_factory import create_plugin_model_runtime +from core.prompt.entities.advanced_prompt_entities import MemoryConfig +from core.workflow.system_variables import default_system_variables from models.provider import ProviderType from tests.workflow_test_utils import build_test_graph_init_params @@ -55,6 +95,62 @@ class MockTokenBufferMemory: return self.history_messages +def _build_prepared_llm_mock() -> mock.MagicMock: + model_instance = mock.MagicMock() + model_instance.provider = "openai" + model_instance.model_name = "gpt-3.5-turbo" + model_instance.parameters = {} + model_instance.stop = () + model_instance.get_llm_num_tokens.return_value = 0 + model_instance.get_model_schema.return_value = AIModelEntity( + model="gpt-3.5-turbo", + label=I18nObject(en_US="GPT-3.5 Turbo"), + model_type=ModelType.LLM, + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + model_properties={}, + ) + model_instance.is_structured_output_parse_error.return_value = False + return model_instance + + +def _build_model_schema( + *, + features: list[ModelFeature] | None = None, + model_properties: dict[ModelPropertyKey, object] | None = None, + parameter_rules: list[ParameterRule] | None = None, +) -> AIModelEntity: + return AIModelEntity( + model="gpt-3.5-turbo", + label=I18nObject(en_US="GPT-3.5 Turbo"), + model_type=ModelType.LLM, + fetch_from=FetchFrom.CUSTOMIZABLE_MODEL, + 
features=features, + model_properties=model_properties or {}, + parameter_rules=parameter_rules or [], + ) + + +def _build_image_file( + *, + file_id: str, + related_id: str, + remote_url: str, + extension: str = ".png", + mime_type: str = "image/png", +) -> File: + return File( + id=file_id, + type=FileType.IMAGE, + filename=f"{file_id}{extension}", + transfer_method=FileTransferMethod.REMOTE_URL, + remote_url=remote_url, + related_id=related_id, + extension=extension, + mime_type=mime_type, + storage_key="", + ) + + @pytest.fixture def llm_node_data() -> LLMNodeData: return LLMNodeData( @@ -91,7 +187,7 @@ def graph_init_params() -> GraphInitParams: @pytest.fixture def graph_runtime_state() -> GraphRuntimeState: variable_pool = VariablePool( - system_variables=SystemVariable.default(), + system_variables=default_system_variables(), user_inputs={}, ) return GraphRuntimeState( @@ -107,7 +203,7 @@ def llm_node( mock_file_saver = mock.MagicMock(spec=LLMFileSaver) mock_credentials_provider = mock.MagicMock(spec=CredentialsProvider) mock_model_factory = mock.MagicMock(spec=ModelFactory) - mock_template_renderer = mock.MagicMock(spec=TemplateRenderer) + mock_prompt_message_serializer = mock.MagicMock(spec=PromptMessageSerializerProtocol) node_config = { "id": "1", "data": llm_node_data.model_dump(), @@ -120,9 +216,9 @@ def llm_node( graph_runtime_state=graph_runtime_state, credentials_provider=mock_credentials_provider, model_factory=mock_model_factory, - model_instance=mock.MagicMock(spec=ModelInstance), + model_instance=_build_prepared_llm_mock(), llm_file_saver=mock_file_saver, - template_renderer=mock_template_renderer, + prompt_message_serializer=mock_prompt_message_serializer, http_client=http_client, ) return node @@ -132,28 +228,31 @@ def llm_node( def model_config(monkeypatch): from tests.integration_tests.model_runtime.__mock.plugin_model import MockModelClass - def mock_plugin_model_providers(_self): - providers = MockModelClass().fetch_model_providers("test") 
- for provider in providers: - provider.declaration.provider = f"{provider.plugin_id}/{provider.declaration.provider}" + def mock_model_providers(_self): + providers = [] + for provider in MockModelClass().fetch_model_providers("test"): + provider_schema = provider.declaration.model_copy(deep=True) + provider_schema.provider = f"{provider.plugin_id}/{provider.provider}" + provider_schema.provider_name = provider.provider + providers.append(provider_schema) return providers monkeypatch.setattr( ModelProviderFactory, - "get_plugin_model_providers", - mock_plugin_model_providers, + "get_model_providers", + mock_model_providers, ) # Create actual provider and model type instances - model_provider_factory = ModelProviderFactory(tenant_id="test") - provider_instance = model_provider_factory.get_plugin_model_provider("openai") + model_provider_factory = ModelProviderFactory(model_runtime=create_plugin_model_runtime(tenant_id="test")) + provider_instance = model_provider_factory.get_model_provider("openai") model_type_instance = model_provider_factory.get_model_type_instance("openai", ModelType.LLM) # Create a ProviderModelBundle provider_model_bundle = ProviderModelBundle( configuration=ProviderConfiguration( tenant_id="1", - provider=provider_instance.declaration, + provider=provider_instance, preferred_provider_type=ProviderType.CUSTOM, using_provider_type=ProviderType.CUSTOM, system_configuration=SystemConfiguration(enabled=False), @@ -181,13 +280,18 @@ def model_config(monkeypatch): ) -def test_fetch_model_config_uses_ports(model_config: ModelConfigWithCredentialsEntity): +def test_fetch_model_config_hydrates_model_instance_runtime_settings(model_config: ModelConfigWithCredentialsEntity): mock_credentials_provider = mock.MagicMock(spec=CredentialsProvider) - mock_model_factory = mock.MagicMock(spec=ModelFactory) + mock_model_factory = mock.MagicMock(spec=DifyModelFactory) provider_model_bundle = model_config.provider_model_bundle model_type_instance = 
provider_model_bundle.model_type_instance provider_model = mock.MagicMock() + completion_params = { + "temperature": 0.7, + "max_tokens": 256, + "stop": ["Observation:", "Human:"], + } model_instance = mock.MagicMock( model_type_instance=model_type_instance, @@ -208,12 +312,36 @@ def test_fetch_model_config_uses_ports(model_config: ModelConfigWithCredentialsE model_type_instance.__class__, "get_model_schema", return_value=model_config.model_schema, autospec=True ), ): - fetch_model_config( - node_data_model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}), + hydrated_model_instance, model_config_with_credentials = fetch_model_config( + node_data_model=ModelConfig( + provider="openai", + name="gpt-3.5-turbo", + mode="chat", + completion_params=completion_params, + ), credentials_provider=mock_credentials_provider, model_factory=mock_model_factory, ) + assert hydrated_model_instance is model_instance + assert hydrated_model_instance.provider == "openai" + assert hydrated_model_instance.model_name == "gpt-3.5-turbo" + assert hydrated_model_instance.credentials == {"api_key": "test"} + assert hydrated_model_instance.parameters == { + "temperature": 0.7, + "max_tokens": 256, + } + assert hydrated_model_instance.stop == ("Observation:", "Human:") + assert model_config_with_credentials.parameters == { + "temperature": 0.7, + "max_tokens": 256, + } + assert model_config_with_credentials.stop == ["Observation:", "Human:"] + assert completion_params == { + "temperature": 0.7, + "max_tokens": 256, + "stop": ["Observation:", "Human:"], + } mock_credentials_provider.fetch.assert_called_once_with("openai", "gpt-3.5-turbo") mock_model_factory.init_model_instance.assert_called_once_with("openai", "gpt-3.5-turbo") provider_model.raise_for_status.assert_called_once() @@ -230,12 +358,20 @@ def test_dify_model_access_adapters_call_managers(): mock_provider_configuration.get_provider_model.return_value = mock_provider_model 
mock_provider_configuration.get_current_credentials.return_value = {"api_key": "test"} - credentials_provider = DifyCredentialsProvider( + run_context = DifyRunContext( tenant_id="tenant", + app_id="app", + user_id="user", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + ) + + credentials_provider = DifyCredentialsProvider( + run_context=run_context, provider_manager=mock_provider_manager, ) model_factory = DifyModelFactory( - tenant_id="tenant", + run_context=run_context, model_manager=mock_model_manager, ) @@ -255,18 +391,18 @@ def test_dify_model_access_adapters_call_managers(): model="gpt-3.5-turbo", ) mock_provider_model.raise_for_status.assert_called_once() - mock_model_manager.get_model_instance.assert_called_once_with( - tenant_id="tenant", - provider="openai", - model_type=ModelType.LLM, - model="gpt-3.5-turbo", - ) + mock_model_manager.get_model_instance.assert_called_once() + assert mock_model_manager.get_model_instance.call_args.kwargs == { + "tenant_id": "tenant", + "provider": "openai", + "model_type": ModelType.LLM, + "model": "gpt-3.5-turbo", + } def test_fetch_files_with_file_segment(): file = File( id="1", - tenant_id="test", type=FileType.IMAGE, filename="test.jpg", transfer_method=FileTransferMethod.LOCAL_FILE, @@ -284,7 +420,6 @@ def test_fetch_files_with_array_file_segment(): files = [ File( id="1", - tenant_id="test", type=FileType.IMAGE, filename="test1.jpg", transfer_method=FileTransferMethod.LOCAL_FILE, @@ -293,7 +428,6 @@ def test_fetch_files_with_array_file_segment(): ), File( id="2", - tenant_id="test", type=FileType.IMAGE, filename="test2.jpg", transfer_method=FileTransferMethod.LOCAL_FILE, @@ -343,7 +477,6 @@ def test_fetch_files_with_non_existent_variable(): # files = [ # File( # id="1", -# tenant_id="test", # type=FileType.IMAGE, # filename="test1.jpg", # transfer_method=FileTransferMethod.REMOTE_URL, @@ -448,7 +581,6 @@ def test_fetch_files_with_non_existent_variable(): # sys_query=fake_query, # sys_files=[ # 
File( -# tenant_id="test", # type=FileType.IMAGE, # filename="test1.jpg", # transfer_method=FileTransferMethod.REMOTE_URL, @@ -524,7 +656,6 @@ def test_fetch_files_with_non_existent_variable(): # + [UserPromptMessage(content=fake_query)], # file_variables={ # "input.image": File( -# tenant_id="test", # type=FileType.IMAGE, # filename="test1.jpg", # transfer_method=FileTransferMethod.REMOTE_URL, @@ -569,7 +700,7 @@ def test_fetch_files_with_non_existent_variable(): def test_handle_list_messages_basic(llm_node): messages = [ LLMNodeChatModelMessage( - text="Hello, {#context#}", + text="Hello, {{#context#}}", role=PromptMessageRole.USER, edition_type="basic", ) @@ -592,32 +723,414 @@ def test_handle_list_messages_basic(llm_node): assert result[0].content == [TextPromptMessageContent(data="Hello, world")] -def test_handle_list_messages_jinja2_uses_template_renderer(llm_node): - llm_node._template_renderer.render_jinja2.return_value = "Hello, world" +def test_handle_list_messages_replaces_double_brace_context_placeholder(llm_node): messages = [ LLMNodeChatModelMessage( - text="", - jinja2_text="Hello, {{ name }}", - role=PromptMessageRole.USER, - edition_type="jinja2", + text="Answer user's question with the following context:\n\n{{#context#}}", + role=PromptMessageRole.SYSTEM, + edition_type="basic", ) ] + context = "## Overview\nSends a JSON request." 
result = llm_node.handle_list_messages( messages=messages, - context=None, + context=context, jinja2_variables=[], variable_pool=llm_node.graph_runtime_state.variable_pool, vision_detail_config=ImagePromptMessageContent.DETAIL.HIGH, - template_renderer=llm_node._template_renderer, ) - assert result == [UserPromptMessage(content=[TextPromptMessageContent(data="Hello, world")])] - llm_node._template_renderer.render_jinja2.assert_called_once_with( - template="Hello, {{ name }}", - inputs={}, + assert len(result) == 1 + assert isinstance(result[0].content, list) + assert result[0].content == [ + TextPromptMessageContent( + data="Answer user's question with the following context:\n\n## Overview\nSends a JSON request." + ) + ] + + +def test_handle_list_messages_renders_jinja2_messages(llm_node): + llm_node.graph_runtime_state.variable_pool.add(["input", "name"], "Dify") + renderer = mock.MagicMock() + renderer.render_template.return_value = "Hello Dify" + + prompt_messages = llm_node.handle_list_messages( + messages=[ + LLMNodeChatModelMessage( + text="ignored", + jinja2_text="Hello {{ name }}", + role=PromptMessageRole.SYSTEM, + edition_type="jinja2", + ) + ], + context="", + jinja2_variables=[VariableSelector(variable="name", value_selector=["input", "name"])], + variable_pool=llm_node.graph_runtime_state.variable_pool, + vision_detail_config=ImagePromptMessageContent.DETAIL.HIGH, + jinja2_template_renderer=renderer, ) + assert prompt_messages == [ + SystemPromptMessage(content=[TextPromptMessageContent(data="Hello Dify")]), + ] + renderer.render_template.assert_called_once_with("Hello {{ name }}", {"name": "Dify"}) + + +def test_transform_chat_messages_prefers_jinja2_text(llm_node): + completion_template = LLMNodeCompletionModelPromptTemplate( + text="ignored", + jinja2_text="completion prompt", + edition_type="jinja2", + ) + chat_messages = [ + LLMNodeChatModelMessage( + text="ignored", + jinja2_text="chat prompt", + role=PromptMessageRole.USER, + 
edition_type="jinja2", + ), + LLMNodeChatModelMessage( + text="keep original", + role=PromptMessageRole.SYSTEM, + edition_type="basic", + ), + ] + + transformed_completion = llm_node._transform_chat_messages(completion_template) + transformed_messages = llm_node._transform_chat_messages(chat_messages) + + assert transformed_completion.text == "completion prompt" + assert transformed_messages[0].text == "chat prompt" + assert transformed_messages[1].text == "keep original" + + +def test_fetch_jinja_inputs_serializes_supported_segment_types(llm_node): + llm_node.graph_runtime_state.variable_pool.add( + ["input", "items"], + ["alpha", {"metadata": {"_source": "knowledge"}, "content": "beta"}, 3], + ) + llm_node.graph_runtime_state.variable_pool.add( + ["input", "context_doc"], + {"metadata": {"_source": "knowledge"}, "content": "context body"}, + ) + llm_node.graph_runtime_state.variable_pool.add(["input", "payload"], {"a": 1}) + + node_data = llm_node.node_data.model_copy( + update={ + "prompt_config": PromptConfig( + jinja2_variables=[ + VariableSelector(variable="items", value_selector=["input", "items"]), + VariableSelector(variable="context_doc", value_selector=["input", "context_doc"]), + VariableSelector(variable="payload", value_selector=["input", "payload"]), + ] + ) + } + ) + + assert llm_node._fetch_jinja_inputs(node_data) == { + "items": "alpha\nbeta\n3", + "context_doc": "context body", + "payload": '{"a": 1}', + } + + +def test_fetch_jinja_inputs_raises_for_missing_variable(llm_node): + node_data = llm_node.node_data.model_copy( + update={ + "prompt_config": PromptConfig( + jinja2_variables=[VariableSelector(variable="missing", value_selector=["input", "missing"])] + ) + } + ) + + with pytest.raises(VariableNotFoundError, match="Variable missing not found"): + llm_node._fetch_jinja_inputs(node_data) + + +def test_fetch_inputs_collects_prompt_and_memory_variables(llm_node): + llm_node.graph_runtime_state.variable_pool.add(["input", "name"], "Dify") + 
llm_node.graph_runtime_state.variable_pool.add(["input", "payload"], {"active": True}) + + node_data = llm_node.node_data.model_copy( + update={ + "prompt_template": [ + LLMNodeChatModelMessage( + text="Hello {{#input.name#}} with {{#input.payload#}}", + role=PromptMessageRole.USER, + edition_type="basic", + ) + ], + "memory": MemoryConfig( + role_prefix=MemoryConfig.RolePrefix(user="Human", assistant="Assistant"), + window=MemoryConfig.WindowConfig(enabled=True, size=1), + query_prompt_template="Repeat {{#input.name#}}", + ), + } + ) + + assert llm_node._fetch_inputs(node_data) == { + "#input.name#": "Dify", + "#input.payload#": {"active": True}, + } + + +def test_fetch_context_emits_string_context_event(llm_node): + llm_node.graph_runtime_state.variable_pool.add(["context", "value"], "retrieved context") + node_data = llm_node.node_data.model_copy( + update={"context": ContextConfig(enabled=True, variable_selector=["context", "value"])} + ) + + events = list(llm_node._fetch_context(node_data)) + + assert events == [ + RunRetrieverResourceEvent(retriever_resources=[], context="retrieved context", context_files=[]), + ] + + +def test_fetch_context_collects_retriever_resources_and_attachments(llm_node): + attachment = _build_image_file( + file_id="attachment", + related_id="attachment-related", + remote_url="https://example.com/attachment.png", + ) + llm_node._retriever_attachment_loader = mock.MagicMock() + llm_node._retriever_attachment_loader.load.return_value = [attachment] + + llm_node.graph_runtime_state.variable_pool.add( + ["context", "value"], + [ + { + "content": "chunk body", + "summary": "chunk summary", + "files": [{"id": "file-1"}], + "metadata": { + "_source": "knowledge", + "dataset_id": "dataset-1", + "segment_id": "segment-1", + "segment_word_count": 12, + }, + }, + "tail text", + ], + ) + node_data = llm_node.node_data.model_copy( + update={"context": ContextConfig(enabled=True, variable_selector=["context", "value"])} + ) + + events = 
list(llm_node._fetch_context(node_data)) + + assert len(events) == 1 + event = events[0] + assert event.context == "chunk summary\nchunk body\ntail text" + assert event.context_files == [attachment] + assert event.retriever_resources == [ + { + "position": None, + "dataset_id": "dataset-1", + "dataset_name": None, + "document_id": None, + "document_name": None, + "data_source_type": None, + "segment_id": "segment-1", + "retriever_from": None, + "score": None, + "hit_count": None, + "word_count": 12, + "segment_position": None, + "index_node_hash": None, + "content": "chunk body", + "page": None, + "doc_metadata": None, + "files": [{"id": "file-1"}], + "summary": "chunk summary", + } + ] + llm_node._retriever_attachment_loader.load.assert_called_once_with(segment_id="segment-1") + + +def test_fetch_context_rejects_invalid_context_structure(llm_node): + llm_node.graph_runtime_state.variable_pool.add(["context", "value"], [{"summary": "missing content"}]) + node_data = llm_node.node_data.model_copy( + update={"context": ContextConfig(enabled=True, variable_selector=["context", "value"])} + ) + + with pytest.raises(InvalidContextStructureError, match="Invalid context structure"): + list(llm_node._fetch_context(node_data)) + + +def test_fetch_prompt_messages_chat_mode_appends_memory_query_and_files(): + model_instance = _build_prepared_llm_mock() + model_instance.get_model_schema.return_value = _build_model_schema(features=[ModelFeature.VISION]) + + memory = mock.MagicMock(spec=MockTokenBufferMemory) + memory.get_history_prompt_messages.return_value = [AssistantPromptMessage(content="history answer")] + + sys_file = _build_image_file(file_id="sys-file", related_id="sys-related", remote_url="https://example.com/sys.png") + context_file = _build_image_file( + file_id="context-file", + related_id="context-related", + remote_url="https://example.com/context.png", + ) + + prompt_content_side_effect = [ + ImagePromptMessageContent( + url="https://example.com/sys.png", + 
format="png", + mime_type="image/png", + detail=ImagePromptMessageContent.DETAIL.HIGH, + ), + ImagePromptMessageContent( + url="https://example.com/context.png", + format="png", + mime_type="image/png", + detail=ImagePromptMessageContent.DETAIL.HIGH, + ), + ] + + with mock.patch("graphon.nodes.llm.node.file_manager.to_prompt_message_content") as mock_to_prompt: + mock_to_prompt.side_effect = prompt_content_side_effect + prompt_messages, stop = LLMNode.fetch_prompt_messages( + sys_query="current question", + sys_files=[sys_file], + context="", + memory=memory, + model_instance=model_instance, + prompt_template=[ + LLMNodeChatModelMessage( + text="Before query", + role=PromptMessageRole.USER, + edition_type="basic", + ) + ], + stop=("STOP",), + memory_config=MemoryConfig( + window=MemoryConfig.WindowConfig(enabled=False), + ), + vision_enabled=True, + vision_detail=ImagePromptMessageContent.DETAIL.HIGH, + variable_pool=VariablePool.empty(), + jinja2_variables=[], + context_files=[context_file], + ) + + assert stop == ("STOP",) + assert prompt_messages[0] == UserPromptMessage(content="Before query") + assert prompt_messages[1] == AssistantPromptMessage(content="history answer") + assert isinstance(prompt_messages[2], UserPromptMessage) + assert isinstance(prompt_messages[2].content, list) + assert isinstance(prompt_messages[2].content[0], ImagePromptMessageContent) + assert isinstance(prompt_messages[2].content[1], ImagePromptMessageContent) + assert isinstance(prompt_messages[2].content[2], TextPromptMessageContent) + assert prompt_messages[2].content[0].url == "https://example.com/context.png" + assert prompt_messages[2].content[1].url == "https://example.com/sys.png" + assert prompt_messages[2].content[2].data == "current question" + memory.get_history_prompt_messages.assert_called_once_with(max_token_limit=2000, message_limit=None) + + +def test_fetch_prompt_messages_completion_mode_injects_histories_and_query(): + model_instance = _build_prepared_llm_mock() + 
model_instance.get_model_schema.return_value = _build_model_schema(features=[]) + + memory = mock.MagicMock(spec=MockTokenBufferMemory) + memory.get_history_prompt_messages.return_value = [ + UserPromptMessage(content="previous question"), + AssistantPromptMessage(content="previous answer"), + ] + + prompt_messages, stop = LLMNode.fetch_prompt_messages( + sys_query="latest question", + sys_files=[], + context="", + memory=memory, + model_instance=model_instance, + prompt_template=LLMNodeCompletionModelPromptTemplate( + text="Prompt header\n#histories#", + edition_type="basic", + ), + stop=("HALT",), + memory_config=MemoryConfig( + role_prefix=MemoryConfig.RolePrefix(user="Human", assistant="Assistant"), + window=MemoryConfig.WindowConfig(enabled=True, size=2), + ), + vision_enabled=False, + vision_detail=ImagePromptMessageContent.DETAIL.HIGH, + variable_pool=VariablePool.empty(), + jinja2_variables=[], + ) + + assert stop == ("HALT",) + assert prompt_messages == [ + UserPromptMessage( + content="latest question\nPrompt header\nHuman: previous question\nAssistant: previous answer" + ) + ] + + +def test_fetch_prompt_messages_raises_when_only_unsupported_content_remains(): + model_instance = _build_prepared_llm_mock() + model_instance.get_model_schema.return_value = _build_model_schema(features=[]) + + variable_pool = VariablePool.empty() + variable_pool.add( + ["input", "image"], + _build_image_file(file_id="image-file", related_id="image-related", remote_url="https://example.com/file.png"), + ) + + with ( + mock.patch( + "graphon.nodes.llm.node.file_manager.to_prompt_message_content", + return_value=ImagePromptMessageContent( + url="https://example.com/file.png", + format="png", + mime_type="image/png", + detail=ImagePromptMessageContent.DETAIL.HIGH, + ), + ), + pytest.raises(NoPromptFoundError, match="No prompt found"), + ): + LLMNode.fetch_prompt_messages( + sys_query=None, + sys_files=[], + context="", + memory=None, + model_instance=model_instance, + 
prompt_template=[ + LLMNodeChatModelMessage( + text="{{#input.image#}}", + role=PromptMessageRole.USER, + edition_type="basic", + ) + ], + stop=None, + memory_config=None, + vision_enabled=False, + vision_detail=ImagePromptMessageContent.DETAIL.HIGH, + variable_pool=variable_pool, + jinja2_variables=[], + ) + + +def test_handle_completion_template_replaces_double_brace_context_placeholder(llm_node): + prompt_messages = _handle_completion_template( + template=LLMNodeCompletionModelPromptTemplate( + text="Summarize the following context:\n{{#context#}}", + edition_type="basic", + ), + context="## Overview\nSends a JSON request.", + jinja2_variables=[], + variable_pool=llm_node.graph_runtime_state.variable_pool, + jinja2_template_renderer=None, + ) + + assert prompt_messages == [ + UserPromptMessage( + content=[ + TextPromptMessageContent(data="Summarize the following context:\n## Overview\nSends a JSON request.") + ] + ) + ] + def test_handle_memory_completion_mode_uses_prompt_message_interface(): memory = mock.MagicMock(spec=MockTokenBufferMemory) @@ -635,15 +1148,15 @@ def test_handle_memory_completion_mode_uses_prompt_message_interface(): AssistantPromptMessage(content="first answer"), ] - model_instance = mock.MagicMock(spec=ModelInstance) + model_instance = _build_prepared_llm_mock() memory_config = MemoryConfig( role_prefix=MemoryConfig.RolePrefix(user="Human", assistant="Assistant"), window=MemoryConfig.WindowConfig(enabled=True, size=3), ) - with mock.patch("dify_graph.nodes.llm.llm_utils.calculate_rest_token", return_value=2000) as mock_rest_token: - memory_text = llm_utils.handle_memory_completion_mode( + with mock.patch("graphon.nodes.llm.node._calculate_rest_token", return_value=2000) as mock_rest_token: + memory_text = _handle_memory_completion_mode( memory=memory, memory_config=memory_config, model_instance=model_instance, @@ -659,7 +1172,7 @@ def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_stat mock_file_saver: LLMFileSaver 
= mock.MagicMock(spec=LLMFileSaver) mock_credentials_provider = mock.MagicMock(spec=CredentialsProvider) mock_model_factory = mock.MagicMock(spec=ModelFactory) - mock_template_renderer = mock.MagicMock(spec=TemplateRenderer) + mock_prompt_message_serializer = mock.MagicMock(spec=PromptMessageSerializerProtocol) node_config = { "id": "1", "data": llm_node_data.model_dump(), @@ -672,9 +1185,9 @@ def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_stat graph_runtime_state=graph_runtime_state, credentials_provider=mock_credentials_provider, model_factory=mock_model_factory, - model_instance=mock.MagicMock(spec=ModelInstance), + model_instance=_build_prepared_llm_mock(), llm_file_saver=mock_file_saver, - template_renderer=mock_template_renderer, + prompt_message_serializer=mock_prompt_message_serializer, http_client=http_client, ) return node, mock_file_saver @@ -690,7 +1203,6 @@ class TestLLMNodeSaveMultiModalImageOutput: ) mock_file = File( id=str(uuid.uuid4()), - tenant_id="1", type=FileType.IMAGE, transfer_method=FileTransferMethod.TOOL_FILE, related_id=str(uuid.uuid4()), @@ -721,7 +1233,6 @@ class TestLLMNodeSaveMultiModalImageOutput: ) mock_file = File( id=str(uuid.uuid4()), - tenant_id="1", type=FileType.IMAGE, transfer_method=FileTransferMethod.TOOL_FILE, related_id=str(uuid.uuid4()), @@ -776,7 +1287,6 @@ class TestSaveMultimodalOutputAndConvertResultToMarkdown: mock_saved_file = File( id=str(uuid.uuid4()), - tenant_id="1", type=FileType.IMAGE, transfer_method=FileTransferMethod.TOOL_FILE, filename="test.png", @@ -906,3 +1416,322 @@ class TestReasoningFormat: assert clean_text == text_with_think assert reasoning_content == "" + + +@pytest.mark.parametrize( + ("structured_output_enabled", "structured_output"), + [ + (False, None), + (True, {"schema": {"type": "object", "properties": {"answer": {"type": "string"}}}}), + ], +) +def test_invoke_llm_dispatches_to_expected_model_method(structured_output_enabled, structured_output): + 
model_instance = _build_prepared_llm_mock() + prompt_messages = [UserPromptMessage(content="hello")] + file_saver = mock.MagicMock(spec=LLMFileSaver) + + model_instance.invoke_llm.return_value = iter([]) + model_instance.invoke_llm_with_structured_output.return_value = iter([]) + + with ( + mock.patch.object(LLMNode, "handle_invoke_result", return_value=iter(["handled"])) as mock_handle, + mock.patch("graphon.nodes.llm.node.time.perf_counter", return_value=10.0), + ): + result = list( + LLMNode.invoke_llm( + model_instance=model_instance, + prompt_messages=prompt_messages, + stop=("STOP",), + structured_output_enabled=structured_output_enabled, + structured_output=structured_output, + file_saver=file_saver, + file_outputs=[], + node_id="node-1", + node_type=LLMNode.node_type, + reasoning_format="separated", + ) + ) + + assert result == ["handled"] + if structured_output_enabled: + model_instance.invoke_llm_with_structured_output.assert_called_once_with( + prompt_messages=prompt_messages, + json_schema={"type": "object", "properties": {"answer": {"type": "string"}}}, + model_parameters={}, + stop=("STOP",), + stream=True, + ) + model_instance.invoke_llm.assert_not_called() + else: + model_instance.invoke_llm.assert_called_once_with( + prompt_messages=prompt_messages, + model_parameters={}, + tools=None, + stop=("STOP",), + stream=True, + ) + model_instance.invoke_llm_with_structured_output.assert_not_called() + + assert mock_handle.call_args.kwargs["request_start_time"] == 10.0 + + +def test_handle_invoke_result_streaming_collects_text_metrics_and_structured_output(): + usage = LLMUsage.from_metadata({"prompt_tokens": 12, "completion_tokens": 4, "total_tokens": 16}) + first_chunk = LLMResultChunkWithStructuredOutput( + model="gpt-3.5-turbo", + prompt_messages=[], + delta=LLMResultChunkDelta( + index=0, + message=AssistantPromptMessage(content=[TextPromptMessageContent(data="plan")]), + ), + structured_output={"draft": True}, + ) + final_chunk = LLMResultChunk( + 
model="gpt-3.5-turbo", + prompt_messages=[], + delta=LLMResultChunkDelta( + index=1, + message=AssistantPromptMessage(content=[TextPromptMessageContent(data="answer")]), + usage=usage, + finish_reason="stop", + ), + ) + + with mock.patch("graphon.nodes.llm.node.time.perf_counter", side_effect=[2.0, 5.0]): + events = list( + LLMNode.handle_invoke_result( + invoke_result=iter([first_chunk, final_chunk]), + file_saver=mock.MagicMock(spec=LLMFileSaver), + file_outputs=[], + node_id="node-1", + node_type=LLMNode.node_type, + model_instance=_build_prepared_llm_mock(), + reasoning_format="separated", + request_start_time=1.0, + ) + ) + + assert events[0] == first_chunk + assert events[1] == StreamChunkEvent(selector=["node-1", "text"], chunk="plan", is_final=False) + assert events[2] == StreamChunkEvent(selector=["node-1", "text"], chunk="answer", is_final=False) + + completed = events[3] + assert isinstance(completed, ModelInvokeCompletedEvent) + assert completed.text == "answer" + assert completed.reasoning_content == "plan" + assert completed.structured_output == {"draft": True} + assert completed.finish_reason == "stop" + assert completed.usage.total_tokens == 16 + assert completed.usage.latency == 4.0 + assert completed.usage.time_to_first_token == 1.0 + assert completed.usage.time_to_generate == 3.0 + + +def test_handle_invoke_result_wraps_structured_output_parse_errors(): + model_instance = _build_prepared_llm_mock() + model_instance.is_structured_output_parse_error.return_value = True + + def broken_stream(): + raise ValueError("bad json") + yield + + with pytest.raises(LLMNodeError, match="Failed to parse structured output: bad json"): + list( + LLMNode.handle_invoke_result( + invoke_result=broken_stream(), + file_saver=mock.MagicMock(spec=LLMFileSaver), + file_outputs=[], + node_id="node-1", + node_type=LLMNode.node_type, + model_instance=model_instance, + ) + ) + + +def test_handle_blocking_result_extracts_reasoning_and_structured_output(): + invoke_result = 
LLMResultWithStructuredOutput( + model="gpt-3.5-turbo", + prompt_messages=[], + message=AssistantPromptMessage(content="reasoningfinal answer"), + usage=LLMUsage.empty_usage(), + structured_output={"answer": "final answer"}, + ) + + event = LLMNode.handle_blocking_result( + invoke_result=invoke_result, + saver=mock.MagicMock(spec=LLMFileSaver), + file_outputs=[], + reasoning_format="separated", + request_latency=1.2345, + ) + + assert event.text == "final answer" + assert event.reasoning_content == "reasoning" + assert event.structured_output == {"answer": "final answer"} + assert event.usage.latency == 1.234 + + +def test_fetch_structured_output_schema_validates_payload(): + assert LLMNode.fetch_structured_output_schema(structured_output={"schema": {"type": "object"}}) == { + "type": "object" + } + + with pytest.raises(LLMNodeError, match="Please provide a valid structured output schema"): + LLMNode.fetch_structured_output_schema(structured_output={}) + + with pytest.raises(LLMNodeError, match="structured_output_schema must be a JSON object"): + LLMNode.fetch_structured_output_schema(structured_output={"schema": ["not", "an", "object"]}) + + +def test_extract_variable_selector_to_variable_mapping_includes_runtime_selectors(): + node_data = LLMNodeData( + title="Test LLM", + model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}), + prompt_template=[ + LLMNodeChatModelMessage( + text="Hello {{#input.name#}}", + role=PromptMessageRole.USER, + edition_type="basic", + ), + LLMNodeChatModelMessage( + text="ignored", + jinja2_text="Hello {{ name }}", + role=PromptMessageRole.SYSTEM, + edition_type="jinja2", + ), + ], + prompt_config=PromptConfig( + jinja2_variables=[VariableSelector(variable="name", value_selector=["input", "name"])] + ), + memory=MemoryConfig( + window=MemoryConfig.WindowConfig(enabled=True, size=1), + query_prompt_template="Repeat {{#sys.query#}}", + ), + context=ContextConfig(enabled=True, 
variable_selector=["context", "value"]), + vision=VisionConfig(enabled=True), + ) + + mapping = LLMNode._extract_variable_selector_to_variable_mapping( + graph_config={}, + node_id="llm-1", + node_data=node_data, + ) + + assert mapping == { + "llm-1.#input.name#": ["input", "name"], + "llm-1.#sys.query#": ["sys", "query"], + "llm-1.#context#": ["context", "value"], + "llm-1.#files#": ["sys", "files"], + "llm-1.name": ["input", "name"], + } + + +def test_render_jinja2_message_requires_renderer_and_passes_inputs(): + variable_pool = VariablePool.empty() + variable_pool.add(["input", "name"], "Dify") + variables = [VariableSelector(variable="name", value_selector=["input", "name"])] + + with pytest.raises( + TemplateRenderError, + match="LLMNode requires an injected jinja2_template_renderer for jinja2 prompts", + ): + _render_jinja2_message( + template="Hello {{ name }}", + jinja2_variables=variables, + variable_pool=variable_pool, + jinja2_template_renderer=None, + ) + + renderer = mock.MagicMock() + renderer.render_template.return_value = "Hello Dify" + + assert ( + _render_jinja2_message( + template="Hello {{ name }}", + jinja2_variables=variables, + variable_pool=variable_pool, + jinja2_template_renderer=renderer, + ) + == "Hello Dify" + ) + renderer.render_template.assert_called_once_with("Hello {{ name }}", {"name": "Dify"}) + + +def test_calculate_rest_token_uses_context_size_and_max_tokens(): + model_instance = _build_prepared_llm_mock() + model_instance.parameters = {"max_tokens": 512} + model_instance.get_model_schema.return_value = _build_model_schema( + model_properties={ModelPropertyKey.CONTEXT_SIZE: 4096}, + parameter_rules=[ + ParameterRule( + name="max_tokens", + label=I18nObject(en_US="Max Tokens"), + type=ParameterType.INT, + ) + ], + ) + model_instance.get_llm_num_tokens.return_value = 1000 + + assert ( + _calculate_rest_token( + prompt_messages=[UserPromptMessage(content="hello")], + model_instance=model_instance, + ) + == 2584 + ) + + +def 
test_handle_memory_chat_mode_uses_calculated_token_budget(): + memory = mock.MagicMock(spec=MockTokenBufferMemory) + history = [UserPromptMessage(content="question")] + memory.get_history_prompt_messages.return_value = history + + with mock.patch("graphon.nodes.llm.node._calculate_rest_token", return_value=321) as mock_rest_token: + result = _handle_memory_chat_mode( + memory=memory, + memory_config=MemoryConfig(window=MemoryConfig.WindowConfig(enabled=True, size=2)), + model_instance=_build_prepared_llm_mock(), + ) + + assert result == history + mock_rest_token.assert_called_once() + memory.get_history_prompt_messages.assert_called_once_with(max_token_limit=321, message_limit=2) + + +def test_dify_model_access_adapters_skip_runtime_build_when_managers_are_injected(): + run_context = DifyRunContext( + tenant_id="tenant", + app_id="app", + user_id="user", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + ) + + with mock.patch("core.app.llm.model_access.create_plugin_provider_manager") as mock_provider_manager_factory: + DifyCredentialsProvider(run_context=run_context, provider_manager=mock.MagicMock()) + DifyModelFactory(run_context=run_context, model_manager=mock.MagicMock()) + + mock_provider_manager_factory.assert_not_called() + + +def test_build_dify_model_access_binds_run_context_user_id_once(): + run_context = DifyRunContext( + tenant_id="tenant", + app_id="app", + user_id="user", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + ) + + with mock.patch("core.app.llm.model_access.create_plugin_provider_manager") as mock_provider_manager: + build_dify_model_access(run_context) + + mock_provider_manager.assert_called_once_with(tenant_id="tenant", user_id="user") + + +def test_dify_model_access_requires_run_context_argument(): + with pytest.raises(TypeError): + DifyCredentialsProvider() + + with pytest.raises(TypeError): + DifyModelFactory() diff --git a/api/tests/unit_tests/core/workflow/nodes/llm/test_scenarios.py 
b/api/tests/unit_tests/core/workflow/nodes/llm/test_scenarios.py deleted file mode 100644 index e40d565ef5..0000000000 --- a/api/tests/unit_tests/core/workflow/nodes/llm/test_scenarios.py +++ /dev/null @@ -1,25 +0,0 @@ -from collections.abc import Mapping, Sequence - -from pydantic import BaseModel, Field - -from dify_graph.file import File -from dify_graph.model_runtime.entities.message_entities import PromptMessage -from dify_graph.model_runtime.entities.model_entities import ModelFeature -from dify_graph.nodes.llm.entities import LLMNodeChatModelMessage - - -class LLMNodeTestScenario(BaseModel): - """Test scenario for LLM node testing.""" - - description: str = Field(..., description="Description of the test scenario") - sys_query: str = Field(..., description="User query input") - sys_files: Sequence[File] = Field(default_factory=list, description="List of user files") - vision_enabled: bool = Field(default=False, description="Whether vision is enabled") - vision_detail: str | None = Field(None, description="Vision detail level if vision is enabled") - features: Sequence[ModelFeature] = Field(default_factory=list, description="List of model features") - window_size: int = Field(..., description="Window size for memory") - prompt_template: Sequence[LLMNodeChatModelMessage] = Field(..., description="Template for prompt messages") - file_variables: Mapping[str, File | Sequence[File]] = Field( - default_factory=dict, description="List of file variables" - ) - expected_messages: Sequence[PromptMessage] = Field(..., description="Expected messages after processing") diff --git a/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_entities.py b/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_entities.py deleted file mode 100644 index fd48edc58c..0000000000 --- a/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_entities.py +++ /dev/null @@ -1,27 +0,0 @@ -from dify_graph.nodes.parameter_extractor.entities import 
ParameterConfig -from dify_graph.variables.types import SegmentType - - -class TestParameterConfig: - def test_select_type(self): - data = { - "name": "yes_or_no", - "type": "select", - "options": ["yes", "no"], - "description": "a simple select made of `yes` and `no`", - "required": True, - } - - pc = ParameterConfig.model_validate(data) - assert pc.type == SegmentType.STRING - assert pc.options == data["options"] - - def test_validate_bool_type(self): - data = { - "name": "boolean", - "type": "bool", - "description": "a simple boolean parameter", - "required": True, - } - pc = ParameterConfig.model_validate(data) - assert pc.type == SegmentType.BOOLEAN diff --git a/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_parameter_extractor_node.py b/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_parameter_extractor_node.py index 7eca531b62..1c362a0a03 100644 --- a/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_parameter_extractor_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/parameter_extractor/test_parameter_extractor_node.py @@ -6,18 +6,18 @@ from dataclasses import dataclass from typing import Any import pytest - -from dify_graph.model_runtime.entities import LLMMode -from dify_graph.nodes.llm import ModelConfig, VisionConfig -from dify_graph.nodes.parameter_extractor.entities import ParameterConfig, ParameterExtractorNodeData -from dify_graph.nodes.parameter_extractor.exc import ( +from graphon.model_runtime.entities import LLMMode +from graphon.nodes.llm import ModelConfig, VisionConfig +from graphon.nodes.parameter_extractor.entities import ParameterConfig, ParameterExtractorNodeData +from graphon.nodes.parameter_extractor.exc import ( InvalidNumberOfParametersError, InvalidSelectValueError, InvalidValueTypeError, RequiredParameterMissingError, ) -from dify_graph.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode -from dify_graph.variables.types import SegmentType +from 
graphon.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode +from graphon.variables.types import SegmentType + from factories.variable_factory import build_segment_with_type diff --git a/api/tests/unit_tests/core/workflow/nodes/template_transform/entities_spec.py b/api/tests/unit_tests/core/workflow/nodes/template_transform/entities_spec.py deleted file mode 100644 index e57ebbd83e..0000000000 --- a/api/tests/unit_tests/core/workflow/nodes/template_transform/entities_spec.py +++ /dev/null @@ -1,225 +0,0 @@ -import pytest -from pydantic import ValidationError - -from dify_graph.enums import ErrorStrategy -from dify_graph.nodes.template_transform.entities import TemplateTransformNodeData - - -class TestTemplateTransformNodeData: - """Test suite for TemplateTransformNodeData entity.""" - - def test_valid_template_transform_node_data(self): - """Test creating valid TemplateTransformNodeData.""" - data = { - "title": "Template Transform", - "desc": "Transform data using Jinja2 template", - "variables": [ - {"variable": "name", "value_selector": ["sys", "user_name"]}, - {"variable": "age", "value_selector": ["sys", "user_age"]}, - ], - "template": "Hello {{ name }}, you are {{ age }} years old!", - } - - node_data = TemplateTransformNodeData.model_validate(data) - - assert node_data.title == "Template Transform" - assert node_data.desc == "Transform data using Jinja2 template" - assert len(node_data.variables) == 2 - assert node_data.variables[0].variable == "name" - assert node_data.variables[0].value_selector == ["sys", "user_name"] - assert node_data.variables[1].variable == "age" - assert node_data.variables[1].value_selector == ["sys", "user_age"] - assert node_data.template == "Hello {{ name }}, you are {{ age }} years old!" 
- - def test_template_transform_node_data_with_empty_variables(self): - """Test TemplateTransformNodeData with no variables.""" - data = { - "title": "Static Template", - "variables": [], - "template": "This is a static template with no variables.", - } - - node_data = TemplateTransformNodeData.model_validate(data) - - assert node_data.title == "Static Template" - assert len(node_data.variables) == 0 - assert node_data.template == "This is a static template with no variables." - - def test_template_transform_node_data_with_complex_template(self): - """Test TemplateTransformNodeData with complex Jinja2 template.""" - data = { - "title": "Complex Template", - "variables": [ - {"variable": "items", "value_selector": ["sys", "item_list"]}, - {"variable": "total", "value_selector": ["sys", "total_count"]}, - ], - "template": ( - "{% for item in items %}{{ item }}{% if not loop.last %}, {% endif %}{% endfor %}. Total: {{ total }}" - ), - } - - node_data = TemplateTransformNodeData.model_validate(data) - - assert node_data.title == "Complex Template" - assert len(node_data.variables) == 2 - assert "{% for item in items %}" in node_data.template - assert "{{ total }}" in node_data.template - - def test_template_transform_node_data_with_error_strategy(self): - """Test TemplateTransformNodeData with error handling strategy.""" - data = { - "title": "Template with Error Handling", - "variables": [{"variable": "value", "value_selector": ["sys", "input"]}], - "template": "{{ value }}", - "error_strategy": "fail-branch", - } - - node_data = TemplateTransformNodeData.model_validate(data) - - assert node_data.error_strategy == ErrorStrategy.FAIL_BRANCH - - def test_template_transform_node_data_with_retry_config(self): - """Test TemplateTransformNodeData with retry configuration.""" - data = { - "title": "Template with Retry", - "variables": [{"variable": "data", "value_selector": ["sys", "data"]}], - "template": "{{ data }}", - "retry_config": {"enabled": True, "max_retries": 3, 
"retry_interval": 1000}, - } - - node_data = TemplateTransformNodeData.model_validate(data) - - assert node_data.retry_config.enabled is True - assert node_data.retry_config.max_retries == 3 - assert node_data.retry_config.retry_interval == 1000 - - def test_template_transform_node_data_missing_required_fields(self): - """Test that missing required fields raises ValidationError.""" - data = { - "title": "Incomplete Template", - # Missing 'variables' and 'template' - } - - with pytest.raises(ValidationError) as exc_info: - TemplateTransformNodeData.model_validate(data) - - errors = exc_info.value.errors() - assert len(errors) >= 2 - error_fields = {error["loc"][0] for error in errors} - assert "variables" in error_fields - assert "template" in error_fields - - def test_template_transform_node_data_invalid_variable_selector(self): - """Test that invalid variable selector format raises ValidationError.""" - data = { - "title": "Invalid Variable", - "variables": [ - {"variable": "name", "value_selector": "invalid_format"} # Should be list - ], - "template": "{{ name }}", - } - - with pytest.raises(ValidationError): - TemplateTransformNodeData.model_validate(data) - - def test_template_transform_node_data_with_default_value_dict(self): - """Test TemplateTransformNodeData with default value dictionary.""" - data = { - "title": "Template with Defaults", - "variables": [ - {"variable": "name", "value_selector": ["sys", "user_name"]}, - {"variable": "greeting", "value_selector": ["sys", "greeting"]}, - ], - "template": "{{ greeting }} {{ name }}!", - "default_value_dict": {"greeting": "Hello", "name": "Guest"}, - } - - node_data = TemplateTransformNodeData.model_validate(data) - - assert node_data.default_value_dict == {"greeting": "Hello", "name": "Guest"} - - def test_template_transform_node_data_with_nested_selectors(self): - """Test TemplateTransformNodeData with nested variable selectors.""" - data = { - "title": "Nested Selectors", - "variables": [ - {"variable": 
"user_info", "value_selector": ["sys", "user", "profile", "name"]}, - {"variable": "settings", "value_selector": ["sys", "config", "app", "theme"]}, - ], - "template": "User: {{ user_info }}, Theme: {{ settings }}", - } - - node_data = TemplateTransformNodeData.model_validate(data) - - assert len(node_data.variables) == 2 - assert node_data.variables[0].value_selector == ["sys", "user", "profile", "name"] - assert node_data.variables[1].value_selector == ["sys", "config", "app", "theme"] - - def test_template_transform_node_data_with_multiline_template(self): - """Test TemplateTransformNodeData with multiline template.""" - data = { - "title": "Multiline Template", - "variables": [ - {"variable": "title", "value_selector": ["sys", "title"]}, - {"variable": "content", "value_selector": ["sys", "content"]}, - ], - "template": """ -# {{ title }} - -{{ content }} - ---- -Generated by Template Transform Node - """, - } - - node_data = TemplateTransformNodeData.model_validate(data) - - assert "# {{ title }}" in node_data.template - assert "{{ content }}" in node_data.template - assert "Generated by Template Transform Node" in node_data.template - - def test_template_transform_node_data_serialization(self): - """Test that TemplateTransformNodeData can be serialized and deserialized.""" - original_data = { - "title": "Serialization Test", - "desc": "Test serialization", - "variables": [{"variable": "test", "value_selector": ["sys", "test"]}], - "template": "{{ test }}", - } - - node_data = TemplateTransformNodeData.model_validate(original_data) - serialized = node_data.model_dump() - deserialized = TemplateTransformNodeData.model_validate(serialized) - - assert deserialized.title == node_data.title - assert deserialized.desc == node_data.desc - assert len(deserialized.variables) == len(node_data.variables) - assert deserialized.template == node_data.template - - def test_template_transform_node_data_with_special_characters(self): - """Test TemplateTransformNodeData with 
special characters in template.""" - data = { - "title": "Special Characters", - "variables": [{"variable": "text", "value_selector": ["sys", "input"]}], - "template": "Special: {{ text }} | Symbols: @#$%^&*() | Unicode: 你好 🎉", - } - - node_data = TemplateTransformNodeData.model_validate(data) - - assert "@#$%^&*()" in node_data.template - assert "你好" in node_data.template - assert "🎉" in node_data.template - - def test_template_transform_node_data_empty_template(self): - """Test TemplateTransformNodeData with empty template string.""" - data = { - "title": "Empty Template", - "variables": [], - "template": "", - } - - node_data = TemplateTransformNodeData.model_validate(data) - - assert node_data.template == "" - assert len(node_data.variables) == 0 diff --git a/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py b/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py index 332a8761f9..d86e0efe02 100644 --- a/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py +++ b/api/tests/unit_tests/core/workflow/nodes/template_transform/template_transform_node_spec.py @@ -1,13 +1,15 @@ from unittest.mock import MagicMock import pytest +from graphon.enums import BuiltinNodeTypes, ErrorStrategy, WorkflowNodeExecutionStatus +from graphon.graph import Graph +from graphon.nodes.base.entities import VariableSelector +from graphon.nodes.template_transform.entities import TemplateTransformNodeData +from graphon.nodes.template_transform.template_transform_node import TemplateTransformNode +from graphon.runtime import GraphRuntimeState +from graphon.template_rendering import TemplateRenderError from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom -from dify_graph.enums import BuiltinNodeTypes, ErrorStrategy, WorkflowNodeExecutionStatus -from dify_graph.graph import Graph -from dify_graph.nodes.template_transform.template_renderer import 
TemplateRenderError -from dify_graph.nodes.template_transform.template_transform_node import TemplateTransformNode -from dify_graph.runtime import GraphRuntimeState from tests.workflow_test_utils import build_test_graph_init_params @@ -62,7 +64,7 @@ class TestTemplateTransformNode: config={"id": "test_node", "data": basic_node_data}, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, - template_renderer=mock_renderer, + jinja2_template_renderer=mock_renderer, ) assert node.node_type == BuiltinNodeTypes.TEMPLATE_TRANSFORM @@ -78,7 +80,7 @@ class TestTemplateTransformNode: config={"id": "test_node", "data": basic_node_data}, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, - template_renderer=mock_renderer, + jinja2_template_renderer=mock_renderer, ) assert node._get_title() == "Template Transform" @@ -91,7 +93,7 @@ class TestTemplateTransformNode: config={"id": "test_node", "data": basic_node_data}, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, - template_renderer=mock_renderer, + jinja2_template_renderer=mock_renderer, ) assert node._get_description() == "Transform data using template" @@ -111,7 +113,7 @@ class TestTemplateTransformNode: config={"id": "test_node", "data": node_data}, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, - template_renderer=mock_renderer, + jinja2_template_renderer=mock_renderer, ) assert node._get_error_strategy() == ErrorStrategy.FAIL_BRANCH @@ -130,6 +132,26 @@ class TestTemplateTransformNode: """Test version class method.""" assert TemplateTransformNode.version() == "1" + @pytest.mark.parametrize("max_output_length", [0, -1]) + def test_node_initialization_rejects_non_positive_max_output_length( + self, + basic_node_data, + mock_graph_runtime_state, + graph_init_params, + max_output_length, + ): + mock_renderer = MagicMock() + + with pytest.raises(ValueError, match="max_output_length must be a 
positive integer"): + TemplateTransformNode( + id="test_node", + config={"id": "test_node", "data": basic_node_data}, + graph_init_params=graph_init_params, + graph_runtime_state=mock_graph_runtime_state, + jinja2_template_renderer=mock_renderer, + max_output_length=max_output_length, + ) + def test_run_simple_template(self, basic_node_data, mock_graph_runtime_state, graph_init_params): """Test _run with simple template transformation using injected renderer.""" # Setup mock variable pool @@ -153,7 +175,7 @@ class TestTemplateTransformNode: config={"id": "test_node", "data": basic_node_data}, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, - template_renderer=mock_renderer, + jinja2_template_renderer=mock_renderer, ) result = node._run() @@ -181,7 +203,7 @@ class TestTemplateTransformNode: config={"id": "test_node", "data": node_data}, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, - template_renderer=mock_renderer, + jinja2_template_renderer=mock_renderer, ) result = node._run() @@ -201,7 +223,7 @@ class TestTemplateTransformNode: config={"id": "test_node", "data": basic_node_data}, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, - template_renderer=mock_renderer, + jinja2_template_renderer=mock_renderer, ) result = node._run() @@ -221,7 +243,7 @@ class TestTemplateTransformNode: config={"id": "test_node", "data": basic_node_data}, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, - template_renderer=mock_renderer, + jinja2_template_renderer=mock_renderer, max_output_length=10, ) @@ -230,6 +252,28 @@ class TestTemplateTransformNode: assert result.status == WorkflowNodeExecutionStatus.FAILED assert "Output length exceeds" in result.error + def test_run_output_length_equal_to_limit_succeeds( + self, basic_node_data, mock_graph_runtime_state, graph_init_params + ): + mock_graph_runtime_state.variable_pool.get.return_value = 
MagicMock() + + mock_renderer = MagicMock() + mock_renderer.render_template.return_value = "1234567890" + + node = TemplateTransformNode( + id="test_node", + config={"id": "test_node", "data": basic_node_data}, + graph_init_params=graph_init_params, + graph_runtime_state=mock_graph_runtime_state, + jinja2_template_renderer=mock_renderer, + max_output_length=10, + ) + + result = node._run() + + assert result.status == WorkflowNodeExecutionStatus.SUCCEEDED + assert result.outputs["output"] == "1234567890" + def test_run_with_complex_jinja2_template(self, mock_graph_runtime_state, graph_init_params): """Test _run with complex Jinja2 template including loops and conditions.""" node_data = { @@ -263,7 +307,7 @@ class TestTemplateTransformNode: config={"id": "test_node", "data": node_data}, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, - template_renderer=mock_renderer, + jinja2_template_renderer=mock_renderer, ) result = node._run() @@ -291,6 +335,69 @@ class TestTemplateTransformNode: assert mapping["node_123.var1"] == ["sys", "input1"] assert mapping["node_123.var2"] == ["sys", "input2"] + def test_extract_variable_selector_to_variable_mapping_accepts_validated_node_data(self): + node_data = TemplateTransformNodeData( + title="Test", + variables=[VariableSelector(variable="var1", value_selector=["sys", "input1"])], + template="{{ var1 }}", + ) + + mapping = TemplateTransformNode._extract_variable_selector_to_variable_mapping( + graph_config={}, node_id="node_123", node_data=node_data + ) + + assert mapping == {"node_123.var1": ["sys", "input1"]} + + def test_extract_variable_selector_to_variable_mapping_returns_empty_mapping_without_variables(self): + node_data = { + "title": "Test", + "template": "{{ missing }}", + } + + mapping = TemplateTransformNode._extract_variable_selector_to_variable_mapping( + graph_config={}, node_id="node_123", node_data=node_data + ) + + assert mapping == {} + + def 
test_extract_variable_selector_to_variable_mapping_accepts_sequence_value_selectors(self): + node_data = { + "title": "Test", + "variables": [ + {"variable": "var1", "value_selector": ("sys", "input1")}, + {"variable": "empty_selector", "value_selector": ()}, + ], + "template": "{{ var1 }}", + } + + mapping = TemplateTransformNode._extract_variable_selector_to_variable_mapping( + graph_config={}, node_id="node_123", node_data=node_data + ) + + assert mapping == { + "node_123.var1": ["sys", "input1"], + "node_123.empty_selector": [], + } + + def test_extract_variable_selector_to_variable_mapping_ignores_invalid_entries(self): + node_data = { + "title": "Test", + "variables": [ + {"variable": "var1", "value_selector": ["sys", "input1"]}, + {"variable": "missing_selector"}, + ["not", "a", "mapping"], + {"variable": 1, "value_selector": ["sys", "input2"]}, + {"variable": "invalid_selector", "value_selector": ["sys", 2]}, + ], + "template": "{{ var1 }}", + } + + mapping = TemplateTransformNode._extract_variable_selector_to_variable_mapping( + graph_config={}, node_id="node_123", node_data=node_data + ) + + assert mapping == {"node_123.var1": ["sys", "input1"]} + def test_run_with_empty_variables(self, mock_graph_runtime_state, graph_init_params): """Test _run with no variables (static template).""" node_data = { @@ -307,7 +414,7 @@ class TestTemplateTransformNode: config={"id": "test_node", "data": node_data}, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, - template_renderer=mock_renderer, + jinja2_template_renderer=mock_renderer, ) result = node._run() @@ -346,7 +453,7 @@ class TestTemplateTransformNode: config={"id": "test_node", "data": node_data}, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, - template_renderer=mock_renderer, + jinja2_template_renderer=mock_renderer, ) result = node._run() @@ -375,7 +482,7 @@ class TestTemplateTransformNode: config={"id": "test_node", "data": node_data}, 
graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, - template_renderer=mock_renderer, + jinja2_template_renderer=mock_renderer, ) result = node._run() @@ -405,7 +512,7 @@ class TestTemplateTransformNode: config={"id": "test_node", "data": node_data}, graph_init_params=graph_init_params, graph_runtime_state=mock_graph_runtime_state, - template_renderer=mock_renderer, + jinja2_template_renderer=mock_renderer, ) result = node._run() diff --git a/api/tests/unit_tests/core/workflow/nodes/template_transform/test_template_transform_node.py b/api/tests/unit_tests/core/workflow/nodes/template_transform/test_template_transform_node.py new file mode 100644 index 0000000000..bd22a8e318 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/template_transform/test_template_transform_node.py @@ -0,0 +1,74 @@ +from unittest.mock import MagicMock + +import pytest +from graphon.nodes.base.entities import VariableSelector +from graphon.nodes.template_transform.template_transform_node import ( + DEFAULT_TEMPLATE_TRANSFORM_MAX_OUTPUT_LENGTH, + TemplateTransformNode, +) +from graphon.runtime import GraphRuntimeState + +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from tests.workflow_test_utils import build_test_graph_init_params + +from .template_transform_node_spec import TestTemplateTransformNode # noqa: F401 + + +@pytest.fixture +def graph_init_params(): + return build_test_graph_init_params( + workflow_id="test_workflow", + graph_config={}, + tenant_id="test_tenant", + app_id="test_app", + user_id="test_user", + user_from=UserFrom.ACCOUNT, + invoke_from=InvokeFrom.DEBUGGER, + call_depth=0, + ) + + +@pytest.fixture +def mock_graph_runtime_state(): + mock_state = MagicMock(spec=GraphRuntimeState) + mock_state.variable_pool = MagicMock() + return mock_state + + +def test_node_uses_default_max_output_length_when_not_overridden(graph_init_params, mock_graph_runtime_state): + node = TemplateTransformNode( + id="test_node", 
+ config={ + "id": "test_node", + "data": { + "title": "Template Transform", + "variables": [], + "template": "hello", + }, + }, + graph_init_params=graph_init_params, + graph_runtime_state=mock_graph_runtime_state, + jinja2_template_renderer=MagicMock(), + ) + + assert node._max_output_length == DEFAULT_TEMPLATE_TRANSFORM_MAX_OUTPUT_LENGTH + + +def test_extract_variable_selector_to_variable_mapping_accepts_mixed_valid_entries(): + mapping = TemplateTransformNode._extract_variable_selector_to_variable_mapping( + graph_config={"ignored": True}, + node_id="node_123", + node_data={ + "variables": [ + VariableSelector(variable="validated", value_selector=["sys", "input1"]), + {"variable": "raw", "value_selector": ("sys", "input2")}, + {"variable": "invalid_selector", "value_selector": ["sys", 3]}, + ["not", "a", "mapping"], + ] + }, + ) + + assert mapping == { + "node_123.validated": ["sys", "input1"], + "node_123.raw": ["sys", "input2"], + } diff --git a/api/tests/unit_tests/core/workflow/nodes/test_base_node.py b/api/tests/unit_tests/core/workflow/nodes/test_base_node.py index 2b0205fb7b..e11ebf6eb8 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_base_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_base_node.py @@ -1,15 +1,16 @@ from collections.abc import Mapping import pytest +from graphon.entities import GraphInitParams +from graphon.entities.base_node_data import BaseNodeData +from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter +from graphon.enums import BuiltinNodeTypes +from graphon.nodes.base.node import Node +from graphon.runtime import GraphRuntimeState, VariablePool from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom -from dify_graph.entities import GraphInitParams -from dify_graph.entities.base_node_data import BaseNodeData -from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter -from dify_graph.enums import BuiltinNodeTypes -from dify_graph.nodes.base.node 
import Node -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable +from core.workflow.node_runtime import resolve_dify_run_context +from core.workflow.system_variables import build_system_variables from tests.workflow_test_utils import build_test_graph_init_params @@ -35,7 +36,7 @@ def _build_context(graph_config: Mapping[str, object]) -> tuple[GraphInitParams, invoke_from="debugger", ) runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=SystemVariable(user_id="user", files=[]), user_inputs={}), + variable_pool=VariablePool(system_variables=build_system_variables(user_id="user", files=[]), user_inputs={}), start_at=0.0, ) return init_params, runtime_state @@ -67,7 +68,7 @@ def test_node_hydrates_data_during_initialization(): assert node.node_data.foo == "bar" assert node.title == "Sample" - dify_ctx = node.require_dify_context() + dify_ctx = resolve_dify_run_context(node.run_context) assert dify_ctx.user_from == "account" assert dify_ctx.invoke_from == "debugger" @@ -80,7 +81,7 @@ def test_node_accepts_invoke_from_enum(): invoke_from=InvokeFrom.DEBUGGER, ) runtime_state = GraphRuntimeState( - variable_pool=VariablePool(system_variables=SystemVariable(user_id="user", files=[]), user_inputs={}), + variable_pool=VariablePool(system_variables=build_system_variables(user_id="user", files=[]), user_inputs={}), start_at=0.0, ) @@ -91,7 +92,7 @@ def test_node_accepts_invoke_from_enum(): graph_runtime_state=runtime_state, ) - dify_ctx = node.require_dify_context() + dify_ctx = resolve_dify_run_context(node.run_context) assert dify_ctx.user_from == UserFrom.ACCOUNT assert dify_ctx.invoke_from == InvokeFrom.DEBUGGER assert node.get_run_context_value("missing") is None @@ -127,3 +128,29 @@ def test_base_node_data_keeps_dict_style_access_compatibility(): assert node_data["foo"] == "bar" assert node_data.get("foo") == "bar" assert node_data.get("missing", "fallback") == "fallback" + + 
+def test_node_hydration_preserves_compatibility_extra_fields(): + graph_config: dict[str, object] = {} + init_params, runtime_state = _build_context(graph_config) + node_config = NodeConfigDictAdapter.validate_python( + { + "id": "node-1", + "data": { + "type": BuiltinNodeTypes.ANSWER, + "title": "Sample", + "foo": "bar", + "compat_flag": True, + }, + } + ) + + node = _SampleNode( + id="node-1", + config=node_config, + graph_init_params=init_params, + graph_runtime_state=runtime_state, + ) + + assert node.node_data.foo == "bar" + assert node.node_data.get("compat_flag") is True diff --git a/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py b/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py index 40754974c1..555ff0c945 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_document_extractor_node.py @@ -4,23 +4,23 @@ from unittest.mock import Mock, patch import pandas as pd import pytest from docx.oxml.text.paragraph import CT_P - -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom -from dify_graph.entities import GraphInitParams -from dify_graph.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus -from dify_graph.file import File, FileTransferMethod -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.document_extractor import DocumentExtractorNode, DocumentExtractorNodeData -from dify_graph.nodes.document_extractor.node import ( +from graphon.entities import GraphInitParams +from graphon.enums import BuiltinNodeTypes, WorkflowNodeExecutionStatus +from graphon.file import File, FileTransferMethod +from graphon.node_events import NodeRunResult +from graphon.nodes.document_extractor import DocumentExtractorNode, DocumentExtractorNodeData +from graphon.nodes.document_extractor.node import ( _extract_text_from_docx, _extract_text_from_excel, _extract_text_from_pdf, 
_extract_text_from_plain_text, _normalize_docx_zip, ) -from dify_graph.variables import ArrayFileSegment -from dify_graph.variables.segments import ArrayStringSegment -from dify_graph.variables.variables import StringVariable +from graphon.variables import ArrayFileSegment +from graphon.variables.segments import ArrayStringSegment +from graphon.variables.variables import StringVariable + +from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom from tests.workflow_test_utils import build_test_graph_init_params @@ -183,14 +183,14 @@ def test_run_extract_text( mock_response.raise_for_status = Mock() document_extractor_node._http_client.get = Mock(return_value=mock_response) - monkeypatch.setattr("dify_graph.file.file_manager.download", mock_download) + monkeypatch.setattr("graphon.file.file_manager.download", mock_download) if mime_type == "application/pdf": mock_pdf_extract = Mock(return_value=expected_text[0]) - monkeypatch.setattr("dify_graph.nodes.document_extractor.node._extract_text_from_pdf", mock_pdf_extract) + monkeypatch.setattr("graphon.nodes.document_extractor.node._extract_text_from_pdf", mock_pdf_extract) elif mime_type.startswith("application/vnd.openxmlformats"): mock_docx_extract = Mock(return_value=expected_text[0]) - monkeypatch.setattr("dify_graph.nodes.document_extractor.node._extract_text_from_docx", mock_docx_extract) + monkeypatch.setattr("graphon.nodes.document_extractor.node._extract_text_from_docx", mock_docx_extract) result = document_extractor_node._run() diff --git a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py index c746a945fe..1b14f0ab13 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_if_else.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_if_else.py @@ -3,19 +3,18 @@ import uuid from unittest.mock import MagicMock, Mock import pytest +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.file import File, 
FileTransferMethod, FileType +from graphon.graph import Graph +from graphon.nodes.if_else.entities import IfElseNodeData +from graphon.nodes.if_else.if_else_node import IfElseNode +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.utils.condition.entities import Condition, SubCondition, SubVariableCondition +from graphon.variables import ArrayFileSegment -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom from core.workflow.node_factory import DifyNodeFactory -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY -from dify_graph.enums import WorkflowNodeExecutionStatus -from dify_graph.file import File, FileTransferMethod, FileType -from dify_graph.graph import Graph -from dify_graph.nodes.if_else.entities import IfElseNodeData -from dify_graph.nodes.if_else.if_else_node import IfElseNode -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable -from dify_graph.utils.condition.entities import Condition, SubCondition, SubVariableCondition -from dify_graph.variables import ArrayFileSegment +from core.workflow.system_variables import build_system_variables from extensions.ext_database import db from tests.workflow_test_utils import build_test_graph_init_params @@ -35,7 +34,7 @@ def test_execute_if_else_result_true(): ) # construct variable pool - pool = VariablePool(system_variables=SystemVariable(user_id="aaa", files=[]), user_inputs={}) + pool = VariablePool(system_variables=build_system_variables(user_id="aaa", files=[]), user_inputs={}) pool.add(["start", "array_contains"], ["ab", "def"]) pool.add(["start", "array_not_contains"], ["ac", "def"]) pool.add(["start", "contains"], "cabcde") @@ -142,7 +141,7 @@ def test_execute_if_else_result_false(): # construct variable pool pool = VariablePool( - system_variables=SystemVariable(user_id="aaa", files=[]), 
+ system_variables=build_system_variables(user_id="aaa", files=[]), user_inputs={}, environment_variables=[], ) @@ -253,7 +252,6 @@ def test_array_file_contains_file_name(): node.graph_runtime_state.variable_pool.get.return_value = ArrayFileSegment( value=[ File( - tenant_id="1", type=FileType.IMAGE, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="1", @@ -316,7 +314,7 @@ def test_execute_if_else_boolean_conditions(condition: Condition): # construct variable pool with boolean values pool = VariablePool( - system_variables=SystemVariable(files=[], user_id="aaa"), + system_variables=build_system_variables(files=[], user_id="aaa"), ) pool.add(["start", "bool_true"], True) pool.add(["start", "bool_false"], False) @@ -371,7 +369,7 @@ def test_execute_if_else_boolean_false_conditions(): # construct variable pool with boolean values pool = VariablePool( - system_variables=SystemVariable(files=[], user_id="aaa"), + system_variables=build_system_variables(files=[], user_id="aaa"), ) pool.add(["start", "bool_true"], True) pool.add(["start", "bool_false"], False) @@ -440,7 +438,7 @@ def test_execute_if_else_boolean_cases_structure(): # construct variable pool with boolean values pool = VariablePool( - system_variables=SystemVariable(files=[], user_id="aaa"), + system_variables=build_system_variables(files=[], user_id="aaa"), ) pool.add(["start", "bool_true"], True) pool.add(["start", "bool_false"], False) diff --git a/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py index 6ca72b64b2..d28c3e01e5 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_list_operator.py @@ -1,12 +1,9 @@ from unittest.mock import MagicMock import pytest - -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY -from dify_graph.enums import 
WorkflowNodeExecutionStatus -from dify_graph.file import File, FileTransferMethod, FileType -from dify_graph.nodes.list_operator.entities import ( +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.file import File, FileTransferMethod, FileType +from graphon.nodes.list_operator.entities import ( ExtractConfig, FilterBy, FilterCondition, @@ -15,9 +12,11 @@ from dify_graph.nodes.list_operator.entities import ( Order, OrderByConfig, ) -from dify_graph.nodes.list_operator.exc import InvalidKeyError -from dify_graph.nodes.list_operator.node import ListOperatorNode, _get_file_extract_string_func -from dify_graph.variables import ArrayFileSegment +from graphon.nodes.list_operator.exc import InvalidKeyError +from graphon.nodes.list_operator.node import ListOperatorNode, _get_file_extract_string_func +from graphon.variables import ArrayFileSegment + +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom @pytest.fixture @@ -72,7 +71,6 @@ def test_filter_files_by_type(list_operator_node): File( filename="image1.jpg", type=FileType.IMAGE, - tenant_id="tenant1", transfer_method=FileTransferMethod.LOCAL_FILE, related_id="related1", storage_key="", @@ -80,7 +78,6 @@ def test_filter_files_by_type(list_operator_node): File( filename="document1.pdf", type=FileType.DOCUMENT, - tenant_id="tenant1", transfer_method=FileTransferMethod.LOCAL_FILE, related_id="related2", storage_key="", @@ -88,7 +85,6 @@ def test_filter_files_by_type(list_operator_node): File( filename="image2.png", type=FileType.IMAGE, - tenant_id="tenant1", transfer_method=FileTransferMethod.LOCAL_FILE, related_id="related3", storage_key="", @@ -96,7 +92,6 @@ def test_filter_files_by_type(list_operator_node): File( filename="audio1.mp3", type=FileType.AUDIO, - tenant_id="tenant1", transfer_method=FileTransferMethod.LOCAL_FILE, related_id="related4", storage_key="", @@ -120,14 +115,12 @@ def test_filter_files_by_type(list_operator_node): { "filename": 
"document1.pdf", "type": FileType.DOCUMENT, - "tenant_id": "tenant1", "transfer_method": FileTransferMethod.LOCAL_FILE, "related_id": "related2", }, { "filename": "image2.png", "type": FileType.IMAGE, - "tenant_id": "tenant1", "transfer_method": FileTransferMethod.LOCAL_FILE, "related_id": "related3", }, @@ -136,7 +129,6 @@ def test_filter_files_by_type(list_operator_node): for expected_file, result_file in zip(expected_files, result.outputs["result"].value): assert expected_file["filename"] == result_file.filename assert expected_file["type"] == result_file.type - assert expected_file["tenant_id"] == result_file.tenant_id assert expected_file["transfer_method"] == result_file.transfer_method assert expected_file["related_id"] == result_file.related_id @@ -144,7 +136,6 @@ def test_filter_files_by_type(list_operator_node): def test_get_file_extract_string_func(): # Create a File object file = File( - tenant_id="test_tenant", type=FileType.DOCUMENT, transfer_method=FileTransferMethod.LOCAL_FILE, filename="test_file.txt", @@ -165,7 +156,6 @@ def test_get_file_extract_string_func(): # Test with empty values empty_file = File( - tenant_id="test_tenant", type=FileType.DOCUMENT, transfer_method=FileTransferMethod.LOCAL_FILE, filename=None, diff --git a/api/tests/unit_tests/core/workflow/nodes/test_loop_node.py b/api/tests/unit_tests/core/workflow/nodes/test_loop_node.py deleted file mode 100644 index 6372583839..0000000000 --- a/api/tests/unit_tests/core/workflow/nodes/test_loop_node.py +++ /dev/null @@ -1,52 +0,0 @@ -from dify_graph.entities.graph_config import NodeConfigDictAdapter -from dify_graph.nodes.loop.entities import LoopNodeData -from dify_graph.nodes.loop.loop_node import LoopNode - - -def test_extract_variable_selector_to_variable_mapping_validates_child_node_configs(monkeypatch) -> None: - seen_configs: list[object] = [] - original_validate_python = NodeConfigDictAdapter.validate_python - - def record_validate_python(value: object): - 
seen_configs.append(value) - return original_validate_python(value) - - monkeypatch.setattr(NodeConfigDictAdapter, "validate_python", record_validate_python) - - child_node_config = { - "id": "answer-node", - "data": { - "type": "answer", - "title": "Answer", - "answer": "", - "loop_id": "loop-node", - }, - } - - LoopNode._extract_variable_selector_to_variable_mapping( - graph_config={ - "nodes": [ - { - "id": "loop-node", - "data": { - "type": "loop", - "title": "Loop", - "loop_count": 1, - "break_conditions": [], - "logical_operator": "and", - }, - }, - child_node_config, - ], - "edges": [], - }, - node_id="loop-node", - node_data=LoopNodeData( - title="Loop", - loop_count=1, - break_conditions=[], - logical_operator="and", - ), - ) - - assert seen_configs == [child_node_config] diff --git a/api/tests/unit_tests/core/workflow/nodes/test_question_classifier_node.py b/api/tests/unit_tests/core/workflow/nodes/test_question_classifier_node.py deleted file mode 100644 index c5a02e87e4..0000000000 --- a/api/tests/unit_tests/core/workflow/nodes/test_question_classifier_node.py +++ /dev/null @@ -1,125 +0,0 @@ -from types import SimpleNamespace -from unittest.mock import MagicMock - -from dify_graph.model_runtime.entities import ImagePromptMessageContent -from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer -from dify_graph.nodes.protocols import HttpClientProtocol -from dify_graph.nodes.question_classifier import ( - QuestionClassifierNode, - QuestionClassifierNodeData, -) -from tests.workflow_test_utils import build_test_graph_init_params - - -def test_init_question_classifier_node_data(): - data = { - "title": "test classifier node", - "query_variable_selector": ["id", "name"], - "model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "completion", "completion_params": {}}, - "classes": [{"id": "1", "name": "class 1"}], - "instruction": "This is a test instruction", - "memory": { - "role_prefix": {"user": "Human:", 
"assistant": "AI:"}, - "window": {"enabled": True, "size": 5}, - "query_prompt_template": "Previous conversation:\n{history}\n\nHuman: {query}\nAI:", - }, - "vision": {"enabled": True, "configs": {"variable_selector": ["image"], "detail": "low"}}, - } - - node_data = QuestionClassifierNodeData.model_validate(data) - - assert node_data.query_variable_selector == ["id", "name"] - assert node_data.model.provider == "openai" - assert node_data.classes[0].id == "1" - assert node_data.instruction == "This is a test instruction" - assert node_data.memory is not None - assert node_data.memory.role_prefix is not None - assert node_data.memory.role_prefix.user == "Human:" - assert node_data.memory.role_prefix.assistant == "AI:" - assert node_data.memory.window.enabled == True - assert node_data.memory.window.size == 5 - assert node_data.memory.query_prompt_template == "Previous conversation:\n{history}\n\nHuman: {query}\nAI:" - assert node_data.vision.enabled == True - assert node_data.vision.configs.variable_selector == ["image"] - assert node_data.vision.configs.detail == ImagePromptMessageContent.DETAIL.LOW - - -def test_init_question_classifier_node_data_without_vision_config(): - data = { - "title": "test classifier node", - "query_variable_selector": ["id", "name"], - "model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "completion", "completion_params": {}}, - "classes": [{"id": "1", "name": "class 1"}], - "instruction": "This is a test instruction", - "memory": { - "role_prefix": {"user": "Human:", "assistant": "AI:"}, - "window": {"enabled": True, "size": 5}, - "query_prompt_template": "Previous conversation:\n{history}\n\nHuman: {query}\nAI:", - }, - } - - node_data = QuestionClassifierNodeData.model_validate(data) - - assert node_data.query_variable_selector == ["id", "name"] - assert node_data.model.provider == "openai" - assert node_data.classes[0].id == "1" - assert node_data.instruction == "This is a test instruction" - assert node_data.memory is 
not None - assert node_data.memory.role_prefix is not None - assert node_data.memory.role_prefix.user == "Human:" - assert node_data.memory.role_prefix.assistant == "AI:" - assert node_data.memory.window.enabled == True - assert node_data.memory.window.size == 5 - assert node_data.memory.query_prompt_template == "Previous conversation:\n{history}\n\nHuman: {query}\nAI:" - assert node_data.vision.enabled == False - assert node_data.vision.configs.variable_selector == ["sys", "files"] - assert node_data.vision.configs.detail == ImagePromptMessageContent.DETAIL.HIGH - - -def test_question_classifier_calculate_rest_token_uses_shared_prompt_builder(monkeypatch): - node_data = QuestionClassifierNodeData.model_validate( - { - "title": "test classifier node", - "query_variable_selector": ["id", "name"], - "model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "completion", "completion_params": {}}, - "classes": [{"id": "1", "name": "class 1"}], - "instruction": "This is a test instruction", - } - ) - template_renderer = MagicMock(spec=TemplateRenderer) - node = QuestionClassifierNode( - id="node-id", - config={"id": "node-id", "data": node_data.model_dump(mode="json")}, - graph_init_params=build_test_graph_init_params( - workflow_id="workflow-id", - graph_config={}, - tenant_id="tenant-id", - app_id="app-id", - user_id="user-id", - ), - graph_runtime_state=SimpleNamespace(variable_pool=MagicMock()), - credentials_provider=MagicMock(spec=CredentialsProvider), - model_factory=MagicMock(spec=ModelFactory), - model_instance=MagicMock(), - http_client=MagicMock(spec=HttpClientProtocol), - llm_file_saver=MagicMock(), - template_renderer=template_renderer, - ) - fetch_prompt_messages = MagicMock(return_value=([], None)) - monkeypatch.setattr( - "dify_graph.nodes.question_classifier.question_classifier_node.llm_utils.fetch_prompt_messages", - fetch_prompt_messages, - ) - monkeypatch.setattr( - 
"dify_graph.nodes.question_classifier.question_classifier_node.llm_utils.fetch_model_schema", - MagicMock(return_value=SimpleNamespace(model_properties={}, parameter_rules=[])), - ) - - node._calculate_rest_token( - node_data=node_data, - query="hello", - model_instance=MagicMock(stop=(), parameters={}), - context="", - ) - - assert fetch_prompt_messages.call_args.kwargs["template_renderer"] is template_renderer diff --git a/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py b/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py index b8f0e25e91..833c303052 100644 --- a/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py +++ b/api/tests/unit_tests/core/workflow/nodes/test_start_node_json_object.py @@ -2,21 +2,24 @@ import json import time import pytest +from graphon.nodes.start.entities import StartNodeData +from graphon.nodes.start.start_node import StartNode +from graphon.runtime import GraphRuntimeState +from graphon.variables import build_segment, segment_to_variable +from graphon.variables.input_entities import VariableEntity, VariableEntityType +from graphon.variables.variables import Variable from pydantic import ValidationError as PydanticValidationError -from dify_graph.nodes.start.entities import StartNodeData -from dify_graph.nodes.start.start_node import StartNode -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable -from dify_graph.variables.input_entities import VariableEntity, VariableEntityType -from tests.workflow_test_utils import build_test_graph_init_params +from core.workflow.system_variables import build_system_variables +from core.workflow.variable_prefixes import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID +from tests.workflow_test_utils import build_test_graph_init_params, build_test_variable_pool def make_start_node(user_inputs, variables): - variable_pool = VariablePool( - 
system_variables=SystemVariable(), - user_inputs=user_inputs, - conversation_variables=[], + variable_pool = build_test_variable_pool( + variables=build_system_variables(), + node_id="start", + inputs=user_inputs, ) config = { @@ -232,3 +235,64 @@ def test_json_object_optional_variable_not_provided(): # Current implementation raises a validation error even when the variable is optional with pytest.raises(ValueError, match="profile is required in input form"): node._run() + + +def test_start_node_outputs_full_variable_pool_snapshot(): + variable_pool = build_test_variable_pool( + variables=[ + *build_system_variables(query="hello", workflow_run_id="run-123"), + _build_prefixed_variable(ENVIRONMENT_VARIABLE_NODE_ID, "API_KEY", "secret"), + _build_prefixed_variable(CONVERSATION_VARIABLE_NODE_ID, "session_id", "conversation-1"), + ], + node_id="start", + inputs={"profile": {"age": 20, "name": "Tom"}}, + ) + + config = { + "id": "start", + "data": StartNodeData( + title="Start", + variables=[ + VariableEntity( + variable="profile", + label="profile", + type=VariableEntityType.JSON_OBJECT, + required=True, + ) + ], + ).model_dump(), + } + + graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) + node = StartNode( + id="start", + config=config, + graph_init_params=build_test_graph_init_params( + workflow_id="wf", + graph_config={}, + tenant_id="tenant", + app_id="app", + user_id="u", + user_from="account", + invoke_from="debugger", + call_depth=0, + ), + graph_runtime_state=graph_runtime_state, + ) + + result = node._run() + + assert result.inputs == {"profile": {"age": 20, "name": "Tom"}} + assert result.outputs["profile"] == {"age": 20, "name": "Tom"} + assert result.outputs["sys.query"] == "hello" + assert result.outputs["sys.workflow_run_id"] == "run-123" + assert result.outputs["env.API_KEY"] == "secret" + assert result.outputs["conversation.session_id"] == "conversation-1" + + +def _build_prefixed_variable(node_id: str, 
name: str, value: object) -> Variable: + return segment_to_variable( + segment=build_segment(value), + selector=(node_id, name), + name=name, + ) diff --git a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py index 3cbd96dfef..1587014802 100644 --- a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node.py @@ -3,23 +3,55 @@ from __future__ import annotations import sys import types from collections.abc import Generator +from types import SimpleNamespace from typing import TYPE_CHECKING, Any -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock import pytest +from graphon.file import File, FileTransferMethod, FileType +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.node_events import StreamChunkEvent, StreamCompletedEvent +from graphon.nodes.tool_runtime_entities import ToolRuntimeHandle, ToolRuntimeMessage +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.variables.segments import ArrayFileSegment -from core.tools.entities.tool_entities import ToolInvokeMessage -from core.tools.utils.message_transformer import ToolFileMessageTransformer -from dify_graph.file import File, FileTransferMethod, FileType -from dify_graph.model_runtime.entities.llm_entities import LLMUsage -from dify_graph.node_events import StreamChunkEvent, StreamCompletedEvent -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable -from dify_graph.variables.segments import ArrayFileSegment +from core.workflow.system_variables import build_system_variables from tests.workflow_test_utils import build_test_graph_init_params if TYPE_CHECKING: # pragma: no cover - imported for type checking only - from dify_graph.nodes.tool.tool_node import ToolNode + from graphon.nodes.tool.tool_node import ToolNode + + 
+class _StubToolRuntime: + def get_runtime(self, *, node_id: str, node_data: Any, variable_pool: Any) -> ToolRuntimeHandle: + raise NotImplementedError + + def get_runtime_parameters(self, *, tool_runtime: ToolRuntimeHandle) -> list[Any]: + return [] + + def invoke( + self, + *, + tool_runtime: ToolRuntimeHandle, + tool_parameters: dict[str, Any], + workflow_call_depth: int, + provider_name: str, + ) -> Generator[ToolRuntimeMessage, None, None]: + yield from () + + def get_usage(self, *, tool_runtime: ToolRuntimeHandle) -> LLMUsage: + return LLMUsage.empty_usage() + + def build_file_reference(self, *, mapping: dict[str, Any]) -> Any: + return mapping + + def resolve_provider_icons( + self, + *, + provider_name: str, + default_icon: str | None = None, + ) -> tuple[str | None, str | None]: + return default_icon, None @pytest.fixture @@ -31,8 +63,8 @@ def tool_node(monkeypatch) -> ToolNode: ops_stub.TraceTask = object # pragma: no cover - stub attribute monkeypatch.setitem(sys.modules, module_name, ops_stub) - from dify_graph.nodes.protocols import ToolFileManagerProtocol - from dify_graph.nodes.tool.tool_node import ToolNode + from graphon.nodes.protocols import ToolFileManagerProtocol + from graphon.nodes.tool.tool_node import ToolNode graph_config: dict[str, Any] = { "nodes": [ @@ -66,13 +98,14 @@ def tool_node(monkeypatch) -> ToolNode: call_depth=0, ) - variable_pool = VariablePool(system_variables=SystemVariable(user_id="user-id")) + variable_pool = VariablePool(system_variables=build_system_variables(user_id="user-id")) graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=0.0) config = graph_config["nodes"][0] # Provide a stub ToolFileManager to satisfy the updated ToolNode constructor tool_file_manager_factory = MagicMock(spec=ToolFileManagerProtocol) + runtime = _StubToolRuntime() node = ToolNode( id="node-instance", @@ -80,6 +113,7 @@ def tool_node(monkeypatch) -> ToolNode: graph_init_params=init_params, 
graph_runtime_state=graph_runtime_state, tool_file_manager_factory=tool_file_manager_factory, + runtime=runtime, ) return node @@ -93,29 +127,19 @@ def _collect_events(generator: Generator) -> tuple[list[Any], LLMUsage]: return events, stop.value -def _run_transform(tool_node: ToolNode, message: ToolInvokeMessage) -> tuple[list[Any], LLMUsage]: - def _identity_transform(messages, *_args, **_kwargs): - return messages - - tool_runtime = MagicMock() - with patch.object( - ToolFileMessageTransformer, "transform_tool_invoke_messages", side_effect=_identity_transform, autospec=True - ): - generator = tool_node._transform_message( - messages=iter([message]), - tool_info={"provider_type": "builtin", "provider_id": "provider"}, - parameters_for_log={}, - user_id="user-id", - tenant_id="tenant-id", - node_id=tool_node._node_id, - tool_runtime=tool_runtime, - ) - return _collect_events(generator) +def _run_transform(tool_node: ToolNode, message: ToolRuntimeMessage) -> tuple[list[Any], LLMUsage]: + generator = tool_node._transform_message( + messages=iter([message]), + tool_info={"provider_type": "builtin", "provider_id": "provider"}, + parameters_for_log={}, + node_id=tool_node._node_id, + tool_runtime=ToolRuntimeHandle(raw=object()), + ) + return _collect_events(generator) def test_link_messages_with_file_populate_files_output(tool_node: ToolNode): file_obj = File( - tenant_id="tenant-id", type=FileType.DOCUMENT, transfer_method=FileTransferMethod.TOOL_FILE, related_id="file-id", @@ -125,9 +149,9 @@ def test_link_messages_with_file_populate_files_output(tool_node: ToolNode): size=123, storage_key="file-key", ) - message = ToolInvokeMessage( - type=ToolInvokeMessage.MessageType.LINK, - message=ToolInvokeMessage.TextMessage(text="/files/tools/file-id.pdf"), + message = ToolRuntimeMessage( + type=ToolRuntimeMessage.MessageType.LINK, + message=ToolRuntimeMessage.TextMessage(text="/files/tools/file-id.pdf"), meta={"file": file_obj}, ) @@ -150,9 +174,9 @@ def 
test_link_messages_with_file_populate_files_output(tool_node: ToolNode): def test_plain_link_messages_remain_links(tool_node: ToolNode): - message = ToolInvokeMessage( - type=ToolInvokeMessage.MessageType.LINK, - message=ToolInvokeMessage.TextMessage(text="https://dify.ai"), + message = ToolRuntimeMessage( + type=ToolRuntimeMessage.MessageType.LINK, + message=ToolRuntimeMessage.TextMessage(text="https://dify.ai"), meta=None, ) @@ -167,3 +191,35 @@ def test_plain_link_messages_remain_links(tool_node: ToolNode): files_segment = completed_events[0].node_run_result.outputs["files"] assert isinstance(files_segment, ArrayFileSegment) assert files_segment.value == [] + + +def test_image_link_messages_use_tool_file_id_metadata(tool_node: ToolNode): + file_obj = File( + type=FileType.DOCUMENT, + transfer_method=FileTransferMethod.TOOL_FILE, + related_id="file-id", + filename="demo.pdf", + extension=".pdf", + mime_type="application/pdf", + size=123, + storage_key="file-key", + ) + tool_node._tool_file_manager_factory.get_file_generator_by_tool_file_id.return_value = ( + None, + SimpleNamespace(mime_type="application/pdf"), + ) + tool_node._runtime.build_file_reference = MagicMock(return_value=file_obj) + message = ToolRuntimeMessage( + type=ToolRuntimeMessage.MessageType.IMAGE_LINK, + message=ToolRuntimeMessage.TextMessage(text="/files/tools/file-id.pdf"), + meta={"tool_file_id": "file-id"}, + ) + + events, _ = _run_transform(tool_node, message) + + tool_node._tool_file_manager_factory.get_file_generator_by_tool_file_id.assert_called_once_with("file-id") + completed_events = [event for event in events if isinstance(event, StreamCompletedEvent)] + assert len(completed_events) == 1 + files_segment = completed_events[0].node_run_result.outputs["files"] + assert isinstance(files_segment, ArrayFileSegment) + assert files_segment.value == [file_obj] diff --git a/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node_runtime.py 
b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node_runtime.py new file mode 100644 index 0000000000..c4dfc5a179 --- /dev/null +++ b/api/tests/unit_tests/core/workflow/nodes/tool/test_tool_node_runtime.py @@ -0,0 +1,276 @@ +from __future__ import annotations + +import sys +import types +from types import SimpleNamespace +from unittest.mock import MagicMock, patch + +import pytest +from graphon.model_runtime.entities.llm_entities import LLMUsage +from graphon.nodes.tool.entities import ToolNodeData, ToolProviderType +from graphon.nodes.tool.exc import ToolRuntimeInvocationError +from graphon.nodes.tool_runtime_entities import ToolRuntimeHandle, ToolRuntimeMessage +from graphon.runtime import VariablePool + +from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler +from core.plugin.impl.exc import PluginDaemonClientSideError, PluginInvokeError +from core.tools.entities.tool_entities import ToolInvokeMessage +from core.tools.entities.tool_entities import ToolProviderType as CoreToolProviderType +from core.tools.errors import ToolInvokeError +from core.tools.tool_engine import ToolEngine +from core.tools.tool_manager import ToolManager +from core.tools.utils.message_transformer import ToolFileMessageTransformer +from core.workflow.node_runtime import DifyToolNodeRuntime +from core.workflow.system_variables import build_system_variables +from tests.workflow_test_utils import build_test_graph_init_params, build_test_variable_pool + + +@pytest.fixture +def runtime(monkeypatch) -> DifyToolNodeRuntime: + module_name = "core.ops.ops_trace_manager" + if module_name not in sys.modules: + ops_stub = types.ModuleType(module_name) + ops_stub.TraceQueueManager = object # pragma: no cover - stub attribute + ops_stub.TraceTask = object # pragma: no cover - stub attribute + monkeypatch.setitem(sys.modules, module_name, ops_stub) + + init_params = build_test_graph_init_params( + workflow_id="workflow-id", + graph_config={"nodes": [], 
"edges": []}, + tenant_id="tenant-id", + app_id="app-id", + user_id="user-id", + user_from="account", + invoke_from="debugger", + call_depth=0, + ) + return DifyToolNodeRuntime(init_params.run_context) + + +def _build_tool_node_data() -> ToolNodeData: + return ToolNodeData.model_validate( + { + "type": "tool", + "title": "Tool", + "provider_id": "provider", + "provider_type": ToolProviderType.BUILT_IN, + "provider_name": "provider", + "tool_name": "lookup", + "tool_label": "Lookup", + "tool_configurations": {}, + "tool_parameters": {}, + } + ) + + +def test_invoke_creates_callback_and_converts_messages(runtime: DifyToolNodeRuntime) -> None: + core_message = ToolInvokeMessage( + type=ToolInvokeMessage.MessageType.LINK, + message=ToolInvokeMessage.TextMessage(text="https://dify.ai"), + meta=None, + ) + variable_pool: VariablePool = build_test_variable_pool( + variables=build_system_variables(conversation_id="conversation-id") + ) + workflow_tool = MagicMock() + + with ( + patch.object(ToolManager, "get_workflow_tool_runtime", return_value=workflow_tool), + patch.object(ToolEngine, "generic_invoke", return_value=iter([core_message])) as generic_invoke_mock, + patch.object( + ToolFileMessageTransformer, + "transform_tool_invoke_messages", + side_effect=lambda *, messages, **_: messages, + ) as transform_tool_messages, + ): + tool_runtime = runtime.get_runtime( + node_id="node-id", + node_data=_build_tool_node_data(), + variable_pool=variable_pool, + ) + messages = list( + runtime.invoke( + tool_runtime=tool_runtime, + tool_parameters={}, + workflow_call_depth=0, + provider_name="provider", + ) + ) + + assert not hasattr(tool_runtime, "conversation_id") + assert len(messages) == 1 + graph_message = messages[0] + assert graph_message.type == ToolRuntimeMessage.MessageType.LINK + assert isinstance(graph_message.message, ToolRuntimeMessage.TextMessage) + assert graph_message.message.text == "https://dify.ai" + + callback = 
generic_invoke_mock.call_args.kwargs["workflow_tool_callback"] + assert isinstance(callback, DifyWorkflowCallbackHandler) + assert generic_invoke_mock.call_args.kwargs["conversation_id"] == "conversation-id" + + transform_kwargs = transform_tool_messages.call_args.kwargs + assert transform_kwargs["conversation_id"] == "conversation-id" + + +def test_invoke_maps_plugin_errors_to_graph_errors(runtime: DifyToolNodeRuntime) -> None: + invoke_error = PluginInvokeError('{"error_type":"RateLimit","message":"too many"}') + + with patch.object(ToolEngine, "generic_invoke", side_effect=invoke_error): + with pytest.raises(ToolRuntimeInvocationError, match="An error occurred in the provider"): + runtime.invoke( + tool_runtime=ToolRuntimeHandle(raw=MagicMock()), + tool_parameters={}, + workflow_call_depth=0, + provider_name="provider", + ) + + +def test_get_usage_normalizes_dict_payload(runtime: DifyToolNodeRuntime) -> None: + usage_payload = LLMUsage.empty_usage().model_dump() + usage_payload["total_tokens"] = 42 + + usage = runtime.get_usage( + tool_runtime=ToolRuntimeHandle(raw=SimpleNamespace(latest_usage=usage_payload)), + ) + + assert usage.total_tokens == 42 + + +def test_get_runtime_converts_graph_provider_type_for_tool_manager(runtime: DifyToolNodeRuntime) -> None: + node_data = _build_tool_node_data() + + with patch.object(ToolManager, "get_workflow_tool_runtime", return_value=MagicMock()) as runtime_mock: + tool_runtime = runtime.get_runtime(node_id="node-id", node_data=node_data, variable_pool=None) + + assert not hasattr(tool_runtime, "conversation_id") + workflow_tool = runtime_mock.call_args.args[3] + assert workflow_tool.provider_type == CoreToolProviderType.BUILT_IN + + +def test_get_runtime_parameters_reads_required_flags(runtime: DifyToolNodeRuntime) -> None: + tool_runtime = ToolRuntimeHandle( + raw=SimpleNamespace( + get_merged_runtime_parameters=MagicMock( + return_value=[ + SimpleNamespace(name="city", required=True), + SimpleNamespace(name="country", 
required=False), + ] + ) + ) + ) + + parameters = runtime.get_runtime_parameters(tool_runtime=tool_runtime) + + assert [(parameter.name, parameter.required) for parameter in parameters] == [ + ("city", True), + ("country", False), + ] + + +def test_get_usage_returns_empty_usage_when_tool_has_no_usage(runtime: DifyToolNodeRuntime) -> None: + usage = runtime.get_usage(tool_runtime=ToolRuntimeHandle(raw=SimpleNamespace(latest_usage=None))) + + assert usage == LLMUsage.empty_usage() + + +@pytest.mark.parametrize( + ("payload", "expected_type"), + [ + (ToolInvokeMessage.JsonMessage(json_object={"ok": True}, suppress_output=True), ToolRuntimeMessage.JsonMessage), + (ToolInvokeMessage.BlobMessage(blob=b"bytes"), ToolRuntimeMessage.BlobMessage), + ( + ToolInvokeMessage.BlobChunkMessage( + id="blob-id", + sequence=1, + total_length=5, + blob=b"hello", + end=True, + ), + ToolRuntimeMessage.BlobChunkMessage, + ), + (ToolInvokeMessage.FileMessage(file_marker="marker"), ToolRuntimeMessage.FileMessage), + ( + ToolInvokeMessage.VariableMessage(variable_name="city", variable_value="Tokyo", stream=True), + ToolRuntimeMessage.VariableMessage, + ), + ( + ToolInvokeMessage.LogMessage( + id="log-id", + label="lookup", + status=ToolInvokeMessage.LogMessage.LogStatus.SUCCESS, + data={"count": 1}, + metadata={"source": "tool"}, + ), + ToolRuntimeMessage.LogMessage, + ), + ], +) +def test_convert_message_payload_supports_runtime_message_types( + runtime: DifyToolNodeRuntime, + payload: object, + expected_type: type[object], +) -> None: + message = runtime._convert_message_payload(payload) + + assert isinstance(message, expected_type) + + +def test_convert_message_payload_rejects_unknown_types(runtime: DifyToolNodeRuntime) -> None: + with pytest.raises(TypeError, match="unsupported tool message payload"): + runtime._convert_message_payload(object()) + + +def test_resolve_provider_icons_prefers_builtin_tool_icons(runtime: DifyToolNodeRuntime) -> None: + plugin = SimpleNamespace( + 
plugin_id="langgenius/tools", + name="search", + declaration=SimpleNamespace(icon={"plugin": "icon"}), + ) + builtin_tool = SimpleNamespace( + name="langgenius/tools/search", + icon={"builtin": "icon"}, + icon_dark={"builtin": "dark"}, + ) + + with ( + patch("core.workflow.node_runtime.PluginInstaller") as installer_cls, + patch("core.workflow.node_runtime.BuiltinToolManageService.list_builtin_tools", return_value=[builtin_tool]), + ): + installer_cls.return_value.list_plugins.return_value = [plugin] + + icon, icon_dark = runtime.resolve_provider_icons(provider_name="langgenius/tools/search") + + assert icon == {"builtin": "icon"} + assert icon_dark == {"builtin": "dark"} + + +def test_resolve_provider_icons_returns_default_when_provider_is_unknown(runtime: DifyToolNodeRuntime) -> None: + with ( + patch("core.workflow.node_runtime.PluginInstaller") as installer_cls, + patch("core.workflow.node_runtime.BuiltinToolManageService.list_builtin_tools", return_value=[]), + ): + installer_cls.return_value.list_plugins.return_value = [] + + icon, icon_dark = runtime.resolve_provider_icons(provider_name="unknown", default_icon="fallback") + + assert icon == "fallback" + assert icon_dark is None + + +@pytest.mark.parametrize( + ("exc", "message"), + [ + (PluginDaemonClientSideError("bad request"), "Failed to invoke tool, error: bad request"), + (ToolInvokeError("broken"), "Failed to invoke tool provider: broken"), + (RuntimeError("unexpected"), "unexpected"), + ], +) +def test_map_invocation_exception_normalizes_runtime_errors( + runtime: DifyToolNodeRuntime, + exc: Exception, + message: str, +) -> None: + error = runtime._map_invocation_exception(exc, provider_name="provider") + + assert isinstance(error, ToolRuntimeInvocationError) + assert str(error) == message diff --git a/api/tests/unit_tests/core/workflow/nodes/trigger_plugin/test_trigger_event_node.py b/api/tests/unit_tests/core/workflow/nodes/trigger_plugin/test_trigger_event_node.py index 9aeab0409e..952e798430 
100644 --- a/api/tests/unit_tests/core/workflow/nodes/trigger_plugin/test_trigger_event_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/trigger_plugin/test_trigger_event_node.py @@ -1,13 +1,14 @@ from collections.abc import Mapping +from graphon.entities import GraphInitParams +from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter +from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus +from graphon.runtime import GraphRuntimeState + from core.trigger.constants import TRIGGER_PLUGIN_NODE_TYPE from core.workflow.nodes.trigger_plugin.trigger_event_node import TriggerEventNode -from dify_graph.entities import GraphInitParams -from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter -from dify_graph.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable -from tests.workflow_test_utils import build_test_graph_init_params +from core.workflow.system_variables import build_system_variables +from tests.workflow_test_utils import build_test_graph_init_params, build_test_variable_pool def _build_context(graph_config: Mapping[str, object]) -> tuple[GraphInitParams, GraphRuntimeState]: @@ -17,9 +18,10 @@ def _build_context(graph_config: Mapping[str, object]) -> tuple[GraphInitParams, invoke_from="debugger", ) runtime_state = GraphRuntimeState( - variable_pool=VariablePool( - system_variables=SystemVariable(user_id="user", files=[]), - user_inputs={"payload": "value"}, + variable_pool=build_test_variable_pool( + variables=build_system_variables(user_id="user", files=[]), + node_id="node-1", + inputs={"payload": "value"}, ), start_at=0.0, ) diff --git a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py deleted file mode 
100644 index e69c05dc0b..0000000000 --- a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v1/test_variable_assigner_v1.py +++ /dev/null @@ -1,308 +0,0 @@ -import time -import uuid -from uuid import uuid4 - -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom -from core.workflow.node_factory import DifyNodeFactory -from dify_graph.entities import GraphInitParams -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY -from dify_graph.graph import Graph -from dify_graph.graph_events.node import NodeRunSucceededEvent -from dify_graph.nodes.variable_assigner.common import helpers as common_helpers -from dify_graph.nodes.variable_assigner.v1 import VariableAssignerNode -from dify_graph.nodes.variable_assigner.v1.node_data import WriteMode -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable -from dify_graph.variables import ArrayStringVariable, StringVariable - -DEFAULT_NODE_ID = "node_id" - - -def test_overwrite_string_variable(): - graph_config = { - "edges": [ - { - "id": "start-source-assigner-target", - "source": "start", - "target": "assigner", - }, - ], - "nodes": [ - {"data": {"type": "start", "title": "Start"}, "id": "start"}, - { - "data": { - "type": "assigner", - "title": "Variable Assigner", - "assigned_variable_selector": ["conversation", "test_conversation_variable"], - "write_mode": "over-write", - "input_variable_selector": ["node_id", "test_string_variable"], - }, - "id": "assigner", - }, - ], - } - - init_params = GraphInitParams( - workflow_id="1", - graph_config=graph_config, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "1", - "app_id": "1", - "user_id": "1", - "user_from": UserFrom.ACCOUNT, - "invoke_from": InvokeFrom.DEBUGGER, - } - }, - call_depth=0, - ) - - conversation_variable = StringVariable( - id=str(uuid4()), - name="test_conversation_variable", - value="the first value", - ) - - input_variable = StringVariable( - 
id=str(uuid4()), - name="test_string_variable", - value="the second value", - ) - conversation_id = str(uuid.uuid4()) - - # construct variable pool - variable_pool = VariablePool( - system_variables=SystemVariable(conversation_id=conversation_id), - user_inputs={}, - environment_variables=[], - conversation_variables=[conversation_variable], - ) - - variable_pool.add( - [DEFAULT_NODE_ID, input_variable.name], - input_variable, - ) - - graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - node_factory = DifyNodeFactory( - graph_init_params=init_params, - graph_runtime_state=graph_runtime_state, - ) - graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") - - node_config = { - "id": "node_id", - "data": { - "title": "test", - "assigned_variable_selector": ["conversation", conversation_variable.name], - "write_mode": WriteMode.OVER_WRITE, - "input_variable_selector": [DEFAULT_NODE_ID, input_variable.name], - }, - } - - node = VariableAssignerNode( - id=str(uuid.uuid4()), - graph_init_params=init_params, - graph_runtime_state=graph_runtime_state, - config=node_config, - ) - - events = list(node.run()) - succeeded_event = next(event for event in events if isinstance(event, NodeRunSucceededEvent)) - updated_variables = common_helpers.get_updated_variables(succeeded_event.node_run_result.process_data) - assert updated_variables is not None - assert updated_variables[0].name == conversation_variable.name - assert updated_variables[0].new_value == input_variable.value - - got = variable_pool.get(["conversation", conversation_variable.name]) - assert got is not None - assert got.value == "the second value" - assert got.to_object() == "the second value" - - -def test_append_variable_to_array(): - graph_config = { - "edges": [ - { - "id": "start-source-assigner-target", - "source": "start", - "target": "assigner", - }, - ], - "nodes": [ - {"data": {"type": "start", "title": "Start"}, "id": 
"start"}, - { - "data": { - "type": "assigner", - "title": "Variable Assigner", - "assigned_variable_selector": ["conversation", "test_conversation_variable"], - "write_mode": "append", - "input_variable_selector": ["node_id", "test_string_variable"], - }, - "id": "assigner", - }, - ], - } - - init_params = GraphInitParams( - workflow_id="1", - graph_config=graph_config, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "1", - "app_id": "1", - "user_id": "1", - "user_from": UserFrom.ACCOUNT, - "invoke_from": InvokeFrom.DEBUGGER, - } - }, - call_depth=0, - ) - - conversation_variable = ArrayStringVariable( - id=str(uuid4()), - name="test_conversation_variable", - value=["the first value"], - ) - - input_variable = StringVariable( - id=str(uuid4()), - name="test_string_variable", - value="the second value", - ) - conversation_id = str(uuid.uuid4()) - - variable_pool = VariablePool( - system_variables=SystemVariable(conversation_id=conversation_id), - user_inputs={}, - environment_variables=[], - conversation_variables=[conversation_variable], - ) - variable_pool.add( - [DEFAULT_NODE_ID, input_variable.name], - input_variable, - ) - - graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - node_factory = DifyNodeFactory( - graph_init_params=init_params, - graph_runtime_state=graph_runtime_state, - ) - graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") - - node_config = { - "id": "node_id", - "data": { - "title": "test", - "assigned_variable_selector": ["conversation", conversation_variable.name], - "write_mode": WriteMode.APPEND, - "input_variable_selector": [DEFAULT_NODE_ID, input_variable.name], - }, - } - - node = VariableAssignerNode( - id=str(uuid.uuid4()), - graph_init_params=init_params, - graph_runtime_state=graph_runtime_state, - config=node_config, - ) - - events = list(node.run()) - succeeded_event = next(event for event in events if isinstance(event, 
NodeRunSucceededEvent)) - updated_variables = common_helpers.get_updated_variables(succeeded_event.node_run_result.process_data) - assert updated_variables is not None - assert updated_variables[0].name == conversation_variable.name - assert updated_variables[0].new_value == ["the first value", "the second value"] - - got = variable_pool.get(["conversation", conversation_variable.name]) - assert got is not None - assert got.to_object() == ["the first value", "the second value"] - - -def test_clear_array(): - graph_config = { - "edges": [ - { - "id": "start-source-assigner-target", - "source": "start", - "target": "assigner", - }, - ], - "nodes": [ - {"data": {"type": "start", "title": "Start"}, "id": "start"}, - { - "data": { - "type": "assigner", - "title": "Variable Assigner", - "assigned_variable_selector": ["conversation", "test_conversation_variable"], - "write_mode": "clear", - "input_variable_selector": [], - }, - "id": "assigner", - }, - ], - } - - init_params = GraphInitParams( - workflow_id="1", - graph_config=graph_config, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "1", - "app_id": "1", - "user_id": "1", - "user_from": UserFrom.ACCOUNT, - "invoke_from": InvokeFrom.DEBUGGER, - } - }, - call_depth=0, - ) - - conversation_variable = ArrayStringVariable( - id=str(uuid4()), - name="test_conversation_variable", - value=["the first value"], - ) - - conversation_id = str(uuid.uuid4()) - variable_pool = VariablePool( - system_variables=SystemVariable(conversation_id=conversation_id), - user_inputs={}, - environment_variables=[], - conversation_variables=[conversation_variable], - ) - - graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - node_factory = DifyNodeFactory( - graph_init_params=init_params, - graph_runtime_state=graph_runtime_state, - ) - graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") - - node_config = { - "id": "node_id", - "data": { - 
"title": "test", - "assigned_variable_selector": ["conversation", conversation_variable.name], - "write_mode": WriteMode.CLEAR, - "input_variable_selector": [], - }, - } - - node = VariableAssignerNode( - id=str(uuid.uuid4()), - graph_init_params=init_params, - graph_runtime_state=graph_runtime_state, - config=node_config, - ) - - events = list(node.run()) - succeeded_event = next(event for event in events if isinstance(event, NodeRunSucceededEvent)) - updated_variables = common_helpers.get_updated_variables(succeeded_event.node_run_result.process_data) - assert updated_variables is not None - assert updated_variables[0].name == conversation_variable.name - assert updated_variables[0].new_value == [] - - got = variable_pool.get(["conversation", conversation_variable.name]) - assert got is not None - assert got.to_object() == [] diff --git a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_helpers.py b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_helpers.py deleted file mode 100644 index a7673c5a14..0000000000 --- a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_helpers.py +++ /dev/null @@ -1,22 +0,0 @@ -from dify_graph.nodes.variable_assigner.v2.enums import Operation -from dify_graph.nodes.variable_assigner.v2.helpers import is_input_value_valid -from dify_graph.variables import SegmentType - - -def test_is_input_value_valid_overwrite_array_string(): - # Valid cases - assert is_input_value_valid( - variable_type=SegmentType.ARRAY_STRING, operation=Operation.OVER_WRITE, value=["hello", "world"] - ) - assert is_input_value_valid(variable_type=SegmentType.ARRAY_STRING, operation=Operation.OVER_WRITE, value=[]) - - # Invalid cases - assert not is_input_value_valid( - variable_type=SegmentType.ARRAY_STRING, operation=Operation.OVER_WRITE, value="not an array" - ) - assert not is_input_value_valid( - variable_type=SegmentType.ARRAY_STRING, operation=Operation.OVER_WRITE, value=[1, 2, 3] - ) - assert not 
is_input_value_valid( - variable_type=SegmentType.ARRAY_STRING, operation=Operation.OVER_WRITE, value=["valid", 123, "invalid"] - ) diff --git a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py b/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py deleted file mode 100644 index 6874f3fef1..0000000000 --- a/api/tests/unit_tests/core/workflow/nodes/variable_assigner/v2/test_variable_assigner_v2.py +++ /dev/null @@ -1,451 +0,0 @@ -import time -import uuid -from uuid import uuid4 - -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom -from core.workflow.node_factory import DifyNodeFactory -from dify_graph.entities import GraphInitParams -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY -from dify_graph.graph import Graph -from dify_graph.nodes.variable_assigner.v2 import VariableAssignerNode -from dify_graph.nodes.variable_assigner.v2.enums import InputType, Operation -from dify_graph.runtime import GraphRuntimeState, VariablePool -from dify_graph.system_variable import SystemVariable -from dify_graph.variables import ArrayStringVariable - -DEFAULT_NODE_ID = "node_id" - - -def test_handle_item_directly(): - """Test the _handle_item method directly for remove operations.""" - # Create variables - variable1 = ArrayStringVariable( - id=str(uuid4()), - name="test_variable1", - value=["first", "second", "third"], - ) - - variable2 = ArrayStringVariable( - id=str(uuid4()), - name="test_variable2", - value=["first", "second", "third"], - ) - - # Create a mock class with just the _handle_item method - class MockNode: - def _handle_item(self, *, variable, operation, value): - match operation: - case Operation.REMOVE_FIRST: - if not variable.value: - return variable.value - return variable.value[1:] - case Operation.REMOVE_LAST: - if not variable.value: - return variable.value - return variable.value[:-1] - - node = MockNode() - - # Test remove-first - 
result1 = node._handle_item( - variable=variable1, - operation=Operation.REMOVE_FIRST, - value=None, - ) - - # Test remove-last - result2 = node._handle_item( - variable=variable2, - operation=Operation.REMOVE_LAST, - value=None, - ) - - # Check the results - assert result1 == ["second", "third"] - assert result2 == ["first", "second"] - - -def test_remove_first_from_array(): - """Test removing the first element from an array.""" - graph_config = { - "edges": [ - { - "id": "start-source-assigner-target", - "source": "start", - "target": "assigner", - }, - ], - "nodes": [ - {"data": {"type": "start", "title": "Start"}, "id": "start"}, - { - "data": {"type": "assigner", "version": "2", "title": "Variable Assigner", "items": []}, - "id": "assigner", - }, - ], - } - - init_params = GraphInitParams( - workflow_id="1", - graph_config=graph_config, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "1", - "app_id": "1", - "user_id": "1", - "user_from": UserFrom.ACCOUNT, - "invoke_from": InvokeFrom.DEBUGGER, - } - }, - call_depth=0, - ) - - conversation_variable = ArrayStringVariable( - id=str(uuid4()), - name="test_conversation_variable", - value=["first", "second", "third"], - selector=["conversation", "test_conversation_variable"], - ) - - variable_pool = VariablePool( - system_variables=SystemVariable(conversation_id="conversation_id"), - user_inputs={}, - environment_variables=[], - conversation_variables=[conversation_variable], - ) - - graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - node_factory = DifyNodeFactory( - graph_init_params=init_params, - graph_runtime_state=graph_runtime_state, - ) - graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") - - node_config = { - "id": "node_id", - "data": { - "title": "test", - "version": "2", - "items": [ - { - "variable_selector": ["conversation", conversation_variable.name], - "input_type": InputType.VARIABLE, - "operation": 
Operation.REMOVE_FIRST, - "value": None, - } - ], - }, - } - - node = VariableAssignerNode( - id=str(uuid.uuid4()), - graph_init_params=init_params, - graph_runtime_state=graph_runtime_state, - config=node_config, - ) - - # Run the node - result = list(node.run()) - - # Completed run - - got = variable_pool.get(["conversation", conversation_variable.name]) - assert got is not None - assert got.to_object() == ["second", "third"] - - -def test_remove_last_from_array(): - """Test removing the last element from an array.""" - graph_config = { - "edges": [ - { - "id": "start-source-assigner-target", - "source": "start", - "target": "assigner", - }, - ], - "nodes": [ - {"data": {"type": "start", "title": "Start"}, "id": "start"}, - { - "data": {"type": "assigner", "version": "2", "title": "Variable Assigner", "items": []}, - "id": "assigner", - }, - ], - } - - init_params = GraphInitParams( - workflow_id="1", - graph_config=graph_config, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "1", - "app_id": "1", - "user_id": "1", - "user_from": UserFrom.ACCOUNT, - "invoke_from": InvokeFrom.DEBUGGER, - } - }, - call_depth=0, - ) - - conversation_variable = ArrayStringVariable( - id=str(uuid4()), - name="test_conversation_variable", - value=["first", "second", "third"], - selector=["conversation", "test_conversation_variable"], - ) - - variable_pool = VariablePool( - system_variables=SystemVariable(conversation_id="conversation_id"), - user_inputs={}, - environment_variables=[], - conversation_variables=[conversation_variable], - ) - - graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - node_factory = DifyNodeFactory( - graph_init_params=init_params, - graph_runtime_state=graph_runtime_state, - ) - graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") - - node_config = { - "id": "node_id", - "data": { - "title": "test", - "version": "2", - "items": [ - { - "variable_selector": 
["conversation", conversation_variable.name], - "input_type": InputType.VARIABLE, - "operation": Operation.REMOVE_LAST, - "value": None, - } - ], - }, - } - - node = VariableAssignerNode( - id=str(uuid.uuid4()), - graph_init_params=init_params, - graph_runtime_state=graph_runtime_state, - config=node_config, - ) - - list(node.run()) - - got = variable_pool.get(["conversation", conversation_variable.name]) - assert got is not None - assert got.to_object() == ["first", "second"] - - -def test_remove_first_from_empty_array(): - """Test removing the first element from an empty array (should do nothing).""" - graph_config = { - "edges": [ - { - "id": "start-source-assigner-target", - "source": "start", - "target": "assigner", - }, - ], - "nodes": [ - {"data": {"type": "start", "title": "Start"}, "id": "start"}, - { - "data": {"type": "assigner", "version": "2", "title": "Variable Assigner", "items": []}, - "id": "assigner", - }, - ], - } - - init_params = GraphInitParams( - workflow_id="1", - graph_config=graph_config, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "1", - "app_id": "1", - "user_id": "1", - "user_from": UserFrom.ACCOUNT, - "invoke_from": InvokeFrom.DEBUGGER, - } - }, - call_depth=0, - ) - - conversation_variable = ArrayStringVariable( - id=str(uuid4()), - name="test_conversation_variable", - value=[], - selector=["conversation", "test_conversation_variable"], - ) - - variable_pool = VariablePool( - system_variables=SystemVariable(conversation_id="conversation_id"), - user_inputs={}, - environment_variables=[], - conversation_variables=[conversation_variable], - ) - - graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - node_factory = DifyNodeFactory( - graph_init_params=init_params, - graph_runtime_state=graph_runtime_state, - ) - graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") - - node_config = { - "id": "node_id", - "data": { - "title": "test", - 
"version": "2", - "items": [ - { - "variable_selector": ["conversation", conversation_variable.name], - "input_type": InputType.VARIABLE, - "operation": Operation.REMOVE_FIRST, - "value": None, - } - ], - }, - } - - node = VariableAssignerNode( - id=str(uuid.uuid4()), - graph_init_params=init_params, - graph_runtime_state=graph_runtime_state, - config=node_config, - ) - - list(node.run()) - - got = variable_pool.get(["conversation", conversation_variable.name]) - assert got is not None - assert got.to_object() == [] - - -def test_remove_last_from_empty_array(): - """Test removing the last element from an empty array (should do nothing).""" - graph_config = { - "edges": [ - { - "id": "start-source-assigner-target", - "source": "start", - "target": "assigner", - }, - ], - "nodes": [ - {"data": {"type": "start", "title": "Start"}, "id": "start"}, - { - "data": {"type": "assigner", "version": "2", "title": "Variable Assigner", "items": []}, - "id": "assigner", - }, - ], - } - - init_params = GraphInitParams( - workflow_id="1", - graph_config=graph_config, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "1", - "app_id": "1", - "user_id": "1", - "user_from": UserFrom.ACCOUNT, - "invoke_from": InvokeFrom.DEBUGGER, - } - }, - call_depth=0, - ) - - conversation_variable = ArrayStringVariable( - id=str(uuid4()), - name="test_conversation_variable", - value=[], - selector=["conversation", "test_conversation_variable"], - ) - - variable_pool = VariablePool( - system_variables=SystemVariable(conversation_id="conversation_id"), - user_inputs={}, - environment_variables=[], - conversation_variables=[conversation_variable], - ) - - graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - node_factory = DifyNodeFactory( - graph_init_params=init_params, - graph_runtime_state=graph_runtime_state, - ) - graph = Graph.init(graph_config=graph_config, node_factory=node_factory, root_node_id="start") - - node_config = { - "id": 
"node_id", - "data": { - "title": "test", - "version": "2", - "items": [ - { - "variable_selector": ["conversation", conversation_variable.name], - "input_type": InputType.VARIABLE, - "operation": Operation.REMOVE_LAST, - "value": None, - } - ], - }, - } - - node = VariableAssignerNode( - id=str(uuid.uuid4()), - graph_init_params=init_params, - graph_runtime_state=graph_runtime_state, - config=node_config, - ) - - list(node.run()) - - got = variable_pool.get(["conversation", conversation_variable.name]) - assert got is not None - assert got.to_object() == [] - - -def test_node_factory_creates_variable_assigner_node(): - graph_config = { - "edges": [], - "nodes": [ - { - "data": {"type": "assigner", "version": "2", "title": "Variable Assigner", "items": []}, - "id": "assigner", - }, - ], - } - - init_params = GraphInitParams( - workflow_id="1", - graph_config=graph_config, - run_context={ - DIFY_RUN_CONTEXT_KEY: { - "tenant_id": "1", - "app_id": "1", - "user_id": "1", - "user_from": UserFrom.ACCOUNT, - "invoke_from": InvokeFrom.DEBUGGER, - } - }, - call_depth=0, - ) - variable_pool = VariablePool( - system_variables=SystemVariable(conversation_id="conversation_id"), - user_inputs={}, - environment_variables=[], - conversation_variables=[], - ) - graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter()) - - node_factory = DifyNodeFactory( - graph_init_params=init_params, - graph_runtime_state=graph_runtime_state, - ) - - node = node_factory.create_node(graph_config["nodes"][0]) - - assert isinstance(node, VariableAssignerNode) diff --git a/api/tests/unit_tests/core/workflow/nodes/webhook/test_entities.py b/api/tests/unit_tests/core/workflow/nodes/webhook/test_entities.py index 6be5bb23e8..be18391b2c 100644 --- a/api/tests/unit_tests/core/workflow/nodes/webhook/test_entities.py +++ b/api/tests/unit_tests/core/workflow/nodes/webhook/test_entities.py @@ -324,7 +324,7 @@ def test_webhook_body_parameter_edge_cases(): def 
test_webhook_data_inheritance(): """Test WebhookData inherits from BaseNodeData correctly.""" - from dify_graph.entities.base_node_data import BaseNodeData + from graphon.entities.base_node_data import BaseNodeData # Test that WebhookData is a subclass of BaseNodeData assert issubclass(WebhookData, BaseNodeData) diff --git a/api/tests/unit_tests/core/workflow/nodes/webhook/test_exceptions.py b/api/tests/unit_tests/core/workflow/nodes/webhook/test_exceptions.py index ddf1af5a59..f1132af02b 100644 --- a/api/tests/unit_tests/core/workflow/nodes/webhook/test_exceptions.py +++ b/api/tests/unit_tests/core/workflow/nodes/webhook/test_exceptions.py @@ -1,4 +1,5 @@ import pytest +from graphon.entities.exc import BaseNodeError from core.workflow.nodes.trigger_webhook.exc import ( WebhookConfigError, @@ -6,7 +7,6 @@ from core.workflow.nodes.trigger_webhook.exc import ( WebhookNotFoundError, WebhookTimeoutError, ) -from dify_graph.entities.exc import BaseNodeError def test_webhook_node_error_inheritance(): diff --git a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py index 78dd7ce0f3..cccd3fb676 100644 --- a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py +++ b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_file_conversion.py @@ -8,7 +8,11 @@ when passing files to downstream LLM nodes. 
from unittest.mock import Mock, patch -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from graphon.entities import GraphInitParams +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.runtime import GraphRuntimeState, VariablePool + +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom from core.workflow.nodes.trigger_webhook.entities import ( ContentType, Method, @@ -16,11 +20,8 @@ from core.workflow.nodes.trigger_webhook.entities import ( WebhookData, ) from core.workflow.nodes.trigger_webhook.node import TriggerWebhookNode -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY, GraphInitParams -from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from dify_graph.runtime.graph_runtime_state import GraphRuntimeState -from dify_graph.runtime.variable_pool import VariablePool -from dify_graph.system_variable import SystemVariable +from core.workflow.system_variables import default_system_variables +from tests.workflow_test_utils import build_test_variable_pool def create_webhook_node( @@ -96,6 +97,18 @@ def create_test_file_dict( } +def build_webhook_variable_pool(inputs: dict) -> VariablePool: + return build_test_variable_pool( + variables=default_system_variables(), + node_id="webhook-node-1", + inputs=inputs, + ) + + +def expected_factory_mapping(file_dict: dict) -> dict: + return {**file_dict, "upload_file_id": file_dict["related_id"]} + + def test_webhook_node_file_conversion_to_file_variable(): """Test that webhook node converts file dictionaries to FileVariable objects.""" # Create test file dictionary (as it comes from webhook service) @@ -111,9 +124,8 @@ def test_webhook_node_file_conversion_to_file_variable(): ], ) - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={ + variable_pool = build_webhook_variable_pool( + { "webhook_data": { "headers": {}, "query_params": {}, @@ -122,14 +134,14 
@@ def test_webhook_node_file_conversion_to_file_variable(): "image_upload": file_dict, }, } - }, + } ) node = create_webhook_node(data, variable_pool) - # Mock the file factory and variable factory + # Mock the file reference boundary and variable factory with ( - patch("factories.file_factory.build_from_mapping") as mock_file_factory, + patch.object(node._file_reference_factory, "build_from_mapping") as mock_file_factory, patch("core.workflow.nodes.trigger_webhook.node.build_segment_with_type") as mock_segment_factory, patch("core.workflow.nodes.trigger_webhook.node.FileVariable") as mock_file_variable, ): @@ -153,8 +165,7 @@ def test_webhook_node_file_conversion_to_file_variable(): # Verify file factory was called with correct parameters mock_file_factory.assert_called_once_with( - mapping=file_dict, - tenant_id="test-tenant", + mapping=expected_factory_mapping(file_dict), ) # Verify segment factory was called to create FileSegment @@ -184,16 +195,15 @@ def test_webhook_node_file_conversion_with_missing_files(): ], ) - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={ + variable_pool = build_webhook_variable_pool( + { "webhook_data": { "headers": {}, "query_params": {}, "body": {}, "files": {}, # No files } - }, + } ) node = create_webhook_node(data, variable_pool) @@ -219,9 +229,8 @@ def test_webhook_node_file_conversion_with_none_file(): ], ) - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={ + variable_pool = build_webhook_variable_pool( + { "webhook_data": { "headers": {}, "query_params": {}, @@ -230,7 +239,7 @@ def test_webhook_node_file_conversion_with_none_file(): "file": None, }, } - }, + } ) node = create_webhook_node(data, variable_pool) @@ -256,9 +265,8 @@ def test_webhook_node_file_conversion_with_non_dict_file(): ], ) - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={ + variable_pool = build_webhook_variable_pool( + { 
"webhook_data": { "headers": {}, "query_params": {}, @@ -267,7 +275,7 @@ def test_webhook_node_file_conversion_with_non_dict_file(): "file": "not_a_dict", # Wrapped to match node expectation }, } - }, + } ) node = create_webhook_node(data, variable_pool) @@ -300,9 +308,8 @@ def test_webhook_node_file_conversion_mixed_parameters(): ], ) - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={ + variable_pool = build_webhook_variable_pool( + { "webhook_data": { "headers": {}, "query_params": {}, @@ -315,13 +322,13 @@ def test_webhook_node_file_conversion_mixed_parameters(): "file_param": file_dict, }, } - }, + } ) node = create_webhook_node(data, variable_pool) with ( - patch("factories.file_factory.build_from_mapping") as mock_file_factory, + patch.object(node._file_reference_factory, "build_from_mapping") as mock_file_factory, patch("core.workflow.nodes.trigger_webhook.node.build_segment_with_type") as mock_segment_factory, patch("core.workflow.nodes.trigger_webhook.node.FileVariable") as mock_file_variable, ): @@ -350,8 +357,7 @@ def test_webhook_node_file_conversion_mixed_parameters(): # Verify file conversion was called mock_file_factory.assert_called_once_with( - mapping=file_dict, - tenant_id="test-tenant", + mapping=expected_factory_mapping(file_dict), ) @@ -370,9 +376,8 @@ def test_webhook_node_different_file_types(): ], ) - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={ + variable_pool = build_webhook_variable_pool( + { "webhook_data": { "headers": {}, "query_params": {}, @@ -383,13 +388,13 @@ def test_webhook_node_different_file_types(): "video": create_test_file_dict("video.mp4", "video"), }, } - }, + } ) node = create_webhook_node(data, variable_pool) with ( - patch("factories.file_factory.build_from_mapping") as mock_file_factory, + patch.object(node._file_reference_factory, "build_from_mapping") as mock_file_factory, 
patch("core.workflow.nodes.trigger_webhook.node.build_segment_with_type") as mock_segment_factory, patch("core.workflow.nodes.trigger_webhook.node.FileVariable") as mock_file_variable, ): @@ -430,9 +435,8 @@ def test_webhook_node_file_conversion_with_non_dict_wrapper(): ], ) - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={ + variable_pool = build_webhook_variable_pool( + { "webhook_data": { "headers": {}, "query_params": {}, @@ -441,7 +445,7 @@ def test_webhook_node_file_conversion_with_non_dict_wrapper(): "file": "just a string", }, } - }, + } ) node = create_webhook_node(data, variable_pool) diff --git a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py index 139f65d6c3..34c66a4f9f 100644 --- a/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/webhook/test_webhook_node.py @@ -1,8 +1,13 @@ from unittest.mock import patch import pytest +from graphon.entities import GraphInitParams +from graphon.enums import WorkflowNodeExecutionStatus +from graphon.file import File, FileTransferMethod, FileType +from graphon.runtime import GraphRuntimeState, VariablePool +from graphon.variables import FileVariable, StringVariable -from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom +from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, InvokeFrom, UserFrom from core.trigger.constants import TRIGGER_WEBHOOK_NODE_TYPE from core.workflow.nodes.trigger_webhook.entities import ( ContentType, @@ -12,13 +17,8 @@ from core.workflow.nodes.trigger_webhook.entities import ( WebhookParameter, ) from core.workflow.nodes.trigger_webhook.node import TriggerWebhookNode -from dify_graph.entities.graph_init_params import DIFY_RUN_CONTEXT_KEY, GraphInitParams -from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus -from dify_graph.file 
import File, FileTransferMethod, FileType -from dify_graph.runtime.graph_runtime_state import GraphRuntimeState -from dify_graph.runtime.variable_pool import VariablePool -from dify_graph.system_variable import SystemVariable -from dify_graph.variables import FileVariable, StringVariable +from core.workflow.system_variables import default_system_variables +from tests.workflow_test_utils import build_test_variable_pool def create_webhook_node(webhook_data: WebhookData, variable_pool: VariablePool) -> TriggerWebhookNode: @@ -62,6 +62,14 @@ def create_webhook_node(webhook_data: WebhookData, variable_pool: VariablePool) return node +def build_webhook_variable_pool(inputs: dict) -> VariablePool: + return build_test_variable_pool( + variables=default_system_variables(), + node_id="1", + inputs=inputs, + ) + + def test_webhook_node_basic_initialization(): """Test basic webhook node initialization and configuration.""" data = WebhookData( @@ -76,10 +84,7 @@ def test_webhook_node_basic_initialization(): timeout=30, ) - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={}, - ) + variable_pool = build_webhook_variable_pool({}) node = create_webhook_node(data, variable_pool) @@ -119,9 +124,8 @@ def test_webhook_node_run_with_headers(): ], ) - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={ + variable_pool = build_webhook_variable_pool( + { "webhook_data": { "headers": { "Authorization": "Bearer token123", @@ -132,7 +136,7 @@ def test_webhook_node_run_with_headers(): "body": {}, "files": {}, } - }, + } ) node = create_webhook_node(data, variable_pool) @@ -155,9 +159,8 @@ def test_webhook_node_run_with_query_params(): ], ) - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={ + variable_pool = build_webhook_variable_pool( + { "webhook_data": { "headers": {}, "query_params": { @@ -167,7 +170,7 @@ def test_webhook_node_run_with_query_params(): "body": {}, 
"files": {}, } - }, + } ) node = create_webhook_node(data, variable_pool) @@ -191,9 +194,8 @@ def test_webhook_node_run_with_body_params(): ], ) - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={ + variable_pool = build_webhook_variable_pool( + { "webhook_data": { "headers": {}, "query_params": {}, @@ -205,7 +207,7 @@ def test_webhook_node_run_with_body_params(): }, "files": {}, } - }, + } ) node = create_webhook_node(data, variable_pool) @@ -222,7 +224,6 @@ def test_webhook_node_run_with_file_params(): """Test webhook node execution with file parameter extraction.""" # Create mock file objects file1 = File( - tenant_id="1", type=FileType.IMAGE, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="file1", @@ -232,7 +233,6 @@ def test_webhook_node_run_with_file_params(): ) file2 = File( - tenant_id="1", type=FileType.DOCUMENT, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="file2", @@ -250,9 +250,8 @@ def test_webhook_node_run_with_file_params(): ], ) - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={ + variable_pool = build_webhook_variable_pool( + { "webhook_data": { "headers": {}, "query_params": {}, @@ -262,14 +261,14 @@ def test_webhook_node_run_with_file_params(): "document": file2.to_dict(), }, } - }, + } ) node = create_webhook_node(data, variable_pool) - # Mock the file factory to avoid DB-dependent validation on upload_file_id - with patch("factories.file_factory.build_from_mapping") as mock_file_factory: + # Mock the node's file reference boundary to avoid DB-dependent validation on upload_file_id + with patch.object(node._file_reference_factory, "build_from_mapping") as mock_file_factory: - def _to_file(mapping, tenant_id, config=None, strict_type_validation=False): + def _to_file(*, mapping): return File.model_validate(mapping) mock_file_factory.side_effect = _to_file @@ -284,7 +283,6 @@ def test_webhook_node_run_with_file_params(): def 
test_webhook_node_run_mixed_parameters(): """Test webhook node execution with mixed parameter types.""" file_obj = File( - tenant_id="1", type=FileType.IMAGE, transfer_method=FileTransferMethod.LOCAL_FILE, related_id="file1", @@ -303,23 +301,22 @@ def test_webhook_node_run_mixed_parameters(): ], ) - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={ + variable_pool = build_webhook_variable_pool( + { "webhook_data": { "headers": {"Authorization": "Bearer token"}, "query_params": {"version": "v1"}, "body": {"message": "Test message"}, "files": {"upload": file_obj.to_dict()}, } - }, + } ) node = create_webhook_node(data, variable_pool) - # Mock the file factory to avoid DB-dependent validation on upload_file_id - with patch("factories.file_factory.build_from_mapping") as mock_file_factory: + # Mock the node's file reference boundary to avoid DB-dependent validation on upload_file_id + with patch.object(node._file_reference_factory, "build_from_mapping") as mock_file_factory: - def _to_file(mapping, tenant_id, config=None, strict_type_validation=False): + def _to_file(*, mapping): return File.model_validate(mapping) mock_file_factory.side_effect = _to_file @@ -343,10 +340,7 @@ def test_webhook_node_run_empty_webhook_data(): body=[WebhookBodyParameter(name="message", type="string", required=False)], ) - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={}, # No webhook_data - ) + variable_pool = build_webhook_variable_pool({}) # No webhook_data node = create_webhook_node(data, variable_pool) result = node._run() @@ -369,9 +363,8 @@ def test_webhook_node_run_case_insensitive_headers(): ], ) - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={ + variable_pool = build_webhook_variable_pool( + { "webhook_data": { "headers": { "content-type": "application/json", # lowercase @@ -382,7 +375,7 @@ def test_webhook_node_run_case_insensitive_headers(): "body": {}, 
"files": {}, } - }, + } ) node = create_webhook_node(data, variable_pool) @@ -399,12 +392,11 @@ def test_webhook_node_variable_pool_user_inputs(): data = WebhookData(title="Test Webhook") # Add some additional variables to the pool - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={ + variable_pool = build_webhook_variable_pool( + { "webhook_data": {"headers": {}, "query_params": {}, "body": {}, "files": {}}, "other_var": "should_be_included", - }, + } ) variable_pool.add(["node1", "extra"], StringVariable(name="extra", value="extra_value")) @@ -430,16 +422,15 @@ def test_webhook_node_different_methods(method): method=method, ) - variable_pool = VariablePool( - system_variables=SystemVariable.default(), - user_inputs={ + variable_pool = build_webhook_variable_pool( + { "webhook_data": { "headers": {}, "query_params": {}, "body": {}, "files": {}, } - }, + } ) node = create_webhook_node(data, variable_pool) diff --git a/api/tests/unit_tests/core/workflow/test_enums.py b/api/tests/unit_tests/core/workflow/test_enums.py deleted file mode 100644 index e8ce6f60f7..0000000000 --- a/api/tests/unit_tests/core/workflow/test_enums.py +++ /dev/null @@ -1,41 +0,0 @@ -"""Tests for workflow pause related enums and constants.""" - -from dify_graph.enums import ( - WorkflowExecutionStatus, -) - - -class TestWorkflowExecutionStatus: - """Test WorkflowExecutionStatus enum.""" - - def test_is_ended_method(self): - """Test is_ended method for different statuses.""" - # Test ended statuses - ended_statuses = [ - WorkflowExecutionStatus.SUCCEEDED, - WorkflowExecutionStatus.FAILED, - WorkflowExecutionStatus.PARTIAL_SUCCEEDED, - WorkflowExecutionStatus.STOPPED, - ] - - for status in ended_statuses: - assert status.is_ended(), f"{status} should be considered ended" - - # Test non-ended statuses - non_ended_statuses = [ - WorkflowExecutionStatus.SCHEDULED, - WorkflowExecutionStatus.RUNNING, - WorkflowExecutionStatus.PAUSED, - ] - - for status in 
non_ended_statuses: - assert not status.is_ended(), f"{status} should not be considered ended" - - def test_ended_values(self): - """Test ended_values returns the expected status values.""" - assert set(WorkflowExecutionStatus.ended_values()) == { - WorkflowExecutionStatus.SUCCEEDED.value, - WorkflowExecutionStatus.FAILED.value, - WorkflowExecutionStatus.PARTIAL_SUCCEEDED.value, - WorkflowExecutionStatus.STOPPED.value, - } diff --git a/api/tests/unit_tests/core/workflow/test_human_input_compat.py b/api/tests/unit_tests/core/workflow/test_human_input_compat.py new file mode 100644 index 0000000000..cd41c43e4a --- /dev/null +++ b/api/tests/unit_tests/core/workflow/test_human_input_compat.py @@ -0,0 +1,184 @@ +from types import SimpleNamespace + +from graphon.enums import BuiltinNodeTypes +from pydantic import BaseModel + +from core.workflow.human_input_compat import ( + DeliveryMethodType, + EmailDeliveryConfig, + EmailDeliveryMethod, + EmailRecipients, + WebAppDeliveryMethod, + _WebAppDeliveryConfig, + is_human_input_webapp_enabled, + normalize_human_input_node_data_for_graph, + normalize_node_config_for_graph, + normalize_node_data_for_graph, + parse_human_input_delivery_methods, +) + + +def test_email_delivery_config_helpers_render_and_sanitize_text() -> None: + variable_pool = SimpleNamespace( + convert_template=lambda body: SimpleNamespace(text=body.replace("{{#node.value#}}", "42")) + ) + + rendered = EmailDeliveryConfig.render_body_template( + body="Open {{#url#}} and use {{#node.value#}}", + url="https://example.com", + variable_pool=variable_pool, + ) + sanitized = EmailDeliveryConfig.sanitize_subject("Hello\r\n Team") + html = EmailDeliveryConfig.render_markdown_body( + "**Hello** [mail](mailto:test@example.com)" + ) + + assert rendered == "Open https://example.com and use 42" + assert sanitized == "Hello alert(1) Team" + assert "Hello" in html + assert "", - "'; DROP TABLE users; --", - "../../../etc/passwd", - "\\x00\\x00", # null bytes - "A" * 10000, # 
very long input - ], - ) - def test_validate_api_key_auth_args_malicious_input(self, malicious_input): - """Test API key auth args validation - malicious input""" - args = self.mock_args.copy() - args["category"] = malicious_input - - # Verify parameter validator doesn't crash on malicious input - # Should validate normally rather than raising security-related exceptions - ApiKeyAuthService.validate_api_key_auth_args(args) - - @patch("services.auth.api_key_auth_service.db.session") - @patch("services.auth.api_key_auth_service.ApiKeyAuthFactory") - @patch("services.auth.api_key_auth_service.encrypter") - def test_create_provider_auth_database_error_handling(self, mock_encrypter, mock_factory, mock_session): - """Test create provider auth - database error handling""" - # Mock successful auth validation - mock_auth_instance = Mock() - mock_auth_instance.validate_credentials.return_value = True - mock_factory.return_value = mock_auth_instance - - # Mock encryption - mock_encrypter.encrypt_token.return_value = "encrypted_key" - - # Mock database error - mock_session.commit.side_effect = Exception("Database error") - - with pytest.raises(Exception, match="Database error"): - ApiKeyAuthService.create_provider_auth(self.tenant_id, self.mock_args) - - @patch("services.auth.api_key_auth_service.db.session") - def test_get_auth_credentials_invalid_json(self, mock_session): - """Test get auth credentials - invalid JSON""" - # Mock database returning invalid JSON - mock_binding = Mock() - mock_binding.credentials = "invalid json content" - mock_session.query.return_value.where.return_value.first.return_value = mock_binding - - with pytest.raises(json.JSONDecodeError): - ApiKeyAuthService.get_auth_credentials(self.tenant_id, self.category, self.provider) - - @patch("services.auth.api_key_auth_service.db.session") - @patch("services.auth.api_key_auth_service.ApiKeyAuthFactory") - def test_create_provider_auth_factory_exception(self, mock_factory, mock_session): - """Test create 
provider auth - factory exception""" - # Mock factory raising exception - mock_factory.side_effect = Exception("Factory error") - - with pytest.raises(Exception, match="Factory error"): - ApiKeyAuthService.create_provider_auth(self.tenant_id, self.mock_args) - - @patch("services.auth.api_key_auth_service.db.session") - @patch("services.auth.api_key_auth_service.ApiKeyAuthFactory") - @patch("services.auth.api_key_auth_service.encrypter") - def test_create_provider_auth_encryption_exception(self, mock_encrypter, mock_factory, mock_session): - """Test create provider auth - encryption exception""" - # Mock successful auth validation - mock_auth_instance = Mock() - mock_auth_instance.validate_credentials.return_value = True - mock_factory.return_value = mock_auth_instance - - # Mock encryption exception - mock_encrypter.encrypt_token.side_effect = Exception("Encryption error") - - with pytest.raises(Exception, match="Encryption error"): - ApiKeyAuthService.create_provider_auth(self.tenant_id, self.mock_args) - - def test_validate_api_key_auth_args_none_input(self): - """Test API key auth args validation - None input""" - with pytest.raises(TypeError): - ApiKeyAuthService.validate_api_key_auth_args(None) - - def test_validate_api_key_auth_args_dict_credentials_with_list_auth_type(self): - """Test API key auth args validation - dict credentials with list auth_type""" - args = self.mock_args.copy() - args["credentials"]["auth_type"] = ["api_key"] - - # Current implementation checks if auth_type exists and is truthy, list ["api_key"] is truthy - # So this should not raise exception, this test should pass - ApiKeyAuthService.validate_api_key_auth_args(args) diff --git a/api/tests/unit_tests/services/auth/test_auth_integration.py b/api/tests/unit_tests/services/auth/test_auth_integration.py deleted file mode 100644 index 3832a0b8b2..0000000000 --- a/api/tests/unit_tests/services/auth/test_auth_integration.py +++ /dev/null @@ -1,231 +0,0 @@ -""" -API Key Authentication System 
Integration Tests -""" - -import json -from concurrent.futures import ThreadPoolExecutor -from unittest.mock import Mock, patch - -import httpx -import pytest - -from services.auth.api_key_auth_factory import ApiKeyAuthFactory -from services.auth.api_key_auth_service import ApiKeyAuthService -from services.auth.auth_type import AuthType - - -class TestAuthIntegration: - def setup_method(self): - self.tenant_id_1 = "tenant_123" - self.tenant_id_2 = "tenant_456" # For multi-tenant isolation testing - self.category = "search" - - # Realistic authentication configurations - self.firecrawl_credentials = {"auth_type": "bearer", "config": {"api_key": "fc_test_key_123"}} - self.jina_credentials = {"auth_type": "bearer", "config": {"api_key": "jina_test_key_456"}} - self.watercrawl_credentials = {"auth_type": "x-api-key", "config": {"api_key": "wc_test_key_789"}} - - @patch("services.auth.api_key_auth_service.db.session") - @patch("services.auth.firecrawl.firecrawl.httpx.post") - @patch("services.auth.api_key_auth_service.encrypter.encrypt_token") - def test_end_to_end_auth_flow(self, mock_encrypt, mock_http, mock_session): - """Test complete authentication flow: request → validation → encryption → storage""" - mock_http.return_value = self._create_success_response() - mock_encrypt.return_value = "encrypted_fc_test_key_123" - mock_session.add = Mock() - mock_session.commit = Mock() - - args = {"category": self.category, "provider": AuthType.FIRECRAWL, "credentials": self.firecrawl_credentials} - ApiKeyAuthService.create_provider_auth(self.tenant_id_1, args) - - mock_http.assert_called_once() - call_args = mock_http.call_args - assert "https://api.firecrawl.dev/v1/crawl" in call_args[0][0] - assert call_args[1]["headers"]["Authorization"] == "Bearer fc_test_key_123" - - mock_encrypt.assert_called_once_with(self.tenant_id_1, "fc_test_key_123") - mock_session.add.assert_called_once() - mock_session.commit.assert_called_once() - - 
@patch("services.auth.firecrawl.firecrawl.httpx.post") - def test_cross_component_integration(self, mock_http): - """Test factory → provider → HTTP call integration""" - mock_http.return_value = self._create_success_response() - factory = ApiKeyAuthFactory(AuthType.FIRECRAWL, self.firecrawl_credentials) - result = factory.validate_credentials() - - assert result is True - mock_http.assert_called_once() - - @patch("services.auth.api_key_auth_service.db.session") - def test_multi_tenant_isolation(self, mock_session): - """Ensure complete tenant data isolation""" - tenant1_binding = self._create_mock_binding(self.tenant_id_1, AuthType.FIRECRAWL, self.firecrawl_credentials) - tenant2_binding = self._create_mock_binding(self.tenant_id_2, AuthType.JINA, self.jina_credentials) - - mock_session.scalars.return_value.all.return_value = [tenant1_binding] - result1 = ApiKeyAuthService.get_provider_auth_list(self.tenant_id_1) - - mock_session.scalars.return_value.all.return_value = [tenant2_binding] - result2 = ApiKeyAuthService.get_provider_auth_list(self.tenant_id_2) - - assert len(result1) == 1 - assert result1[0].tenant_id == self.tenant_id_1 - assert len(result2) == 1 - assert result2[0].tenant_id == self.tenant_id_2 - - @patch("services.auth.api_key_auth_service.db.session") - def test_cross_tenant_access_prevention(self, mock_session): - """Test prevention of cross-tenant credential access""" - mock_session.query.return_value.where.return_value.first.return_value = None - - result = ApiKeyAuthService.get_auth_credentials(self.tenant_id_2, self.category, AuthType.FIRECRAWL) - - assert result is None - - def test_sensitive_data_protection(self): - """Ensure API keys don't leak to logs""" - credentials_with_secrets = { - "auth_type": "bearer", - "config": {"api_key": "super_secret_key_do_not_log", "secret": "another_secret"}, - } - - factory = ApiKeyAuthFactory(AuthType.FIRECRAWL, credentials_with_secrets) - factory_str = str(factory) - - assert 
"super_secret_key_do_not_log" not in factory_str - assert "another_secret" not in factory_str - - @patch("services.auth.api_key_auth_service.db.session") - @patch("services.auth.firecrawl.firecrawl.httpx.post") - @patch("services.auth.api_key_auth_service.encrypter.encrypt_token") - def test_concurrent_creation_safety(self, mock_encrypt, mock_http, mock_session): - """Test concurrent authentication creation safety""" - mock_http.return_value = self._create_success_response() - mock_encrypt.return_value = "encrypted_key" - mock_session.add = Mock() - mock_session.commit = Mock() - - args = {"category": self.category, "provider": AuthType.FIRECRAWL, "credentials": self.firecrawl_credentials} - - results = [] - exceptions = [] - - def create_auth(): - try: - ApiKeyAuthService.create_provider_auth(self.tenant_id_1, args) - results.append("success") - except Exception as e: - exceptions.append(e) - - with ThreadPoolExecutor(max_workers=5) as executor: - futures = [executor.submit(create_auth) for _ in range(5)] - for future in futures: - future.result() - - assert len(results) == 5 - assert len(exceptions) == 0 - assert mock_session.add.call_count == 5 - assert mock_session.commit.call_count == 5 - - @pytest.mark.parametrize( - "invalid_input", - [ - None, # Null input - {}, # Empty dictionary - missing required fields - {"auth_type": "bearer"}, # Missing config section - {"auth_type": "bearer", "config": {}}, # Missing api_key - ], - ) - def test_invalid_input_boundary(self, invalid_input): - """Test boundary handling for invalid inputs""" - with pytest.raises((ValueError, KeyError, TypeError, AttributeError)): - ApiKeyAuthFactory(AuthType.FIRECRAWL, invalid_input) - - @patch("services.auth.firecrawl.firecrawl.httpx.post") - def test_http_error_handling(self, mock_http): - """Test proper HTTP error handling""" - mock_response = Mock() - mock_response.status_code = 401 - mock_response.text = '{"error": "Unauthorized"}' - mock_response.raise_for_status.side_effect = 
httpx.HTTPError("Unauthorized") - mock_http.return_value = mock_response - - # PT012: Split into single statement for pytest.raises - factory = ApiKeyAuthFactory(AuthType.FIRECRAWL, self.firecrawl_credentials) - with pytest.raises((httpx.HTTPError, Exception)): - factory.validate_credentials() - - @patch("services.auth.api_key_auth_service.db.session") - @patch("services.auth.firecrawl.firecrawl.httpx.post") - def test_network_failure_recovery(self, mock_http, mock_session): - """Test system recovery from network failures""" - mock_http.side_effect = httpx.RequestError("Network timeout") - mock_session.add = Mock() - mock_session.commit = Mock() - - args = {"category": self.category, "provider": AuthType.FIRECRAWL, "credentials": self.firecrawl_credentials} - - with pytest.raises(httpx.RequestError): - ApiKeyAuthService.create_provider_auth(self.tenant_id_1, args) - - mock_session.commit.assert_not_called() - - @pytest.mark.parametrize( - ("provider", "credentials"), - [ - (AuthType.FIRECRAWL, {"auth_type": "bearer", "config": {"api_key": "fc_key"}}), - (AuthType.JINA, {"auth_type": "bearer", "config": {"api_key": "jina_key"}}), - (AuthType.WATERCRAWL, {"auth_type": "x-api-key", "config": {"api_key": "wc_key"}}), - ], - ) - def test_all_providers_factory_creation(self, provider, credentials): - """Test factory creation for all supported providers""" - auth_class = ApiKeyAuthFactory.get_apikey_auth_factory(provider) - assert auth_class is not None - - factory = ApiKeyAuthFactory(provider, credentials) - assert factory.auth is not None - - def _create_success_response(self, status_code=200): - """Create successful HTTP response mock""" - mock_response = Mock() - mock_response.status_code = status_code - mock_response.json.return_value = {"status": "success"} - mock_response.raise_for_status.return_value = None - return mock_response - - def _create_mock_binding(self, tenant_id: str, provider: str, credentials: dict) -> Mock: - """Create realistic database binding 
mock""" - mock_binding = Mock() - mock_binding.id = f"binding_{provider}_{tenant_id}" - mock_binding.tenant_id = tenant_id - mock_binding.category = self.category - mock_binding.provider = provider - mock_binding.credentials = json.dumps(credentials, ensure_ascii=False) - mock_binding.disabled = False - - mock_binding.created_at = Mock() - mock_binding.created_at.timestamp.return_value = 1640995200 - mock_binding.updated_at = Mock() - mock_binding.updated_at.timestamp.return_value = 1640995200 - - return mock_binding - - def test_integration_coverage_validation(self): - """Validate integration test coverage meets quality standards""" - core_scenarios = { - "business_logic": ["end_to_end_auth_flow", "cross_component_integration"], - "security": ["multi_tenant_isolation", "cross_tenant_access_prevention", "sensitive_data_protection"], - "reliability": ["concurrent_creation_safety", "network_failure_recovery"], - "compatibility": ["all_providers_factory_creation"], - "boundaries": ["invalid_input_boundary", "http_error_handling"], - } - - total_scenarios = sum(len(scenarios) for scenarios in core_scenarios.values()) - assert total_scenarios >= 10 - - security_tests = core_scenarios["security"] - assert "multi_tenant_isolation" in security_tests - assert "sensitive_data_protection" in security_tests - assert True diff --git a/api/tests/unit_tests/services/dataset_service_test_helpers.py b/api/tests/unit_tests/services/dataset_service_test_helpers.py new file mode 100644 index 0000000000..ef73bc0e01 --- /dev/null +++ b/api/tests/unit_tests/services/dataset_service_test_helpers.py @@ -0,0 +1,455 @@ +"""Shared helpers for dataset_service unit tests. + +These factories and lightweight builders are reused across the dataset, +document, and segment service test modules that exercise +``api/services/dataset_service.py``. 
+""" + +import json +from types import SimpleNamespace +from unittest.mock import MagicMock, Mock, create_autospec, patch + +import pytest +from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType +from werkzeug.exceptions import Forbidden, NotFound + +from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError +from core.rag.index_processor.constant.built_in_field import BuiltInField +from core.rag.index_processor.constant.index_type import IndexStructureType +from core.rag.retrieval.retrieval_methods import RetrievalMethod +from enums.cloud_plan import CloudPlan +from models import Account, TenantAccountRole +from models.dataset import ( + ChildChunk, + Dataset, + DatasetPermissionEnum, + DatasetProcessRule, + Document, + DocumentSegment, +) +from models.model import UploadFile +from services.dataset_service import ( + DatasetCollectionBindingService, + DatasetPermissionService, + DatasetService, + DocumentService, + SegmentService, +) +from services.entities.knowledge_entities.knowledge_entities import ( + ChildChunkUpdateArgs, + DataSource, + FileInfo, + InfoList, + KnowledgeConfig, + NotionIcon, + NotionInfo, + NotionPage, + PreProcessingRule, + ProcessRule, + RerankingModel, + RetrievalModel, + Rule, + Segmentation, + SegmentUpdateArgs, + WebsiteInfo, +) +from services.entities.knowledge_entities.rag_pipeline_entities import ( + IconInfo as PipelineIconInfo, +) +from services.entities.knowledge_entities.rag_pipeline_entities import ( + KnowledgeConfiguration, + RagPipelineDatasetCreateEntity, +) +from services.entities.knowledge_entities.rag_pipeline_entities import ( + RerankingModelConfig as RagPipelineRerankingModelConfig, +) +from services.entities.knowledge_entities.rag_pipeline_entities import ( + RetrievalSetting as RagPipelineRetrievalSetting, +) +from services.errors.account import NoPermissionError +from services.errors.chunk import ChildChunkDeleteIndexError, ChildChunkIndexingError +from 
services.errors.dataset import DatasetNameDuplicateError +from services.errors.document import DocumentIndexingError +from services.errors.file import FileNotExistsError + +__all__ = [ + "Account", + "BuiltInField", + "ChildChunk", + "ChildChunkDeleteIndexError", + "ChildChunkIndexingError", + "ChildChunkUpdateArgs", + "CloudPlan", + "DataSource", + "Dataset", + "DatasetCollectionBindingService", + "DatasetNameDuplicateError", + "DatasetPermissionEnum", + "DatasetPermissionService", + "DatasetProcessRule", + "DatasetService", + "DatasetServiceUnitDataFactory", + "Document", + "DocumentIndexingError", + "DocumentSegment", + "DocumentService", + "FileInfo", + "FileNotExistsError", + "Forbidden", + "IndexStructureType", + "InfoList", + "KnowledgeConfig", + "KnowledgeConfiguration", + "LLMBadRequestError", + "MagicMock", + "Mock", + "ModelFeature", + "ModelType", + "NoPermissionError", + "NotFound", + "NotionIcon", + "NotionInfo", + "NotionPage", + "PipelineIconInfo", + "PreProcessingRule", + "ProcessRule", + "ProviderTokenNotInitError", + "RagPipelineDatasetCreateEntity", + "RagPipelineRerankingModelConfig", + "RagPipelineRetrievalSetting", + "RerankingModel", + "RetrievalMethod", + "RetrievalModel", + "Rule", + "SegmentService", + "SegmentUpdateArgs", + "Segmentation", + "SimpleNamespace", + "TenantAccountRole", + "WebsiteInfo", + "_make_child_chunk", + "_make_dataset", + "_make_document", + "_make_features", + "_make_knowledge_configuration", + "_make_lock_context", + "_make_retrieval_model", + "_make_segment", + "_make_session_context", + "_make_upload_knowledge_config", + "create_autospec", + "json", + "patch", + "pytest", +] + + +def _make_session_context(session: MagicMock) -> MagicMock: + """Wrap a mocked session in a context manager.""" + context_manager = MagicMock() + context_manager.__enter__.return_value = session + context_manager.__exit__.return_value = False + return context_manager + + +class DatasetServiceUnitDataFactory: + """Factory for lightweight 
doubles used across dataset service tests.""" + + @staticmethod + def create_dataset_mock( + dataset_id: str = "dataset-123", + tenant_id: str = "tenant-123", + *, + permission: str = DatasetPermissionEnum.ALL_TEAM, + created_by: str = "user-123", + indexing_technique: str = "economy", + embedding_model_provider: str = "provider", + embedding_model: str = "model", + built_in_field_enabled: bool = False, + doc_form: str | None = "text_model", + enable_api: bool = False, + summary_index_setting: dict | None = None, + **kwargs, + ) -> Mock: + dataset = Mock(spec=Dataset) + dataset.id = dataset_id + dataset.tenant_id = tenant_id + dataset.permission = permission + dataset.created_by = created_by + dataset.indexing_technique = indexing_technique + dataset.embedding_model_provider = embedding_model_provider + dataset.embedding_model = embedding_model + dataset.built_in_field_enabled = built_in_field_enabled + dataset.doc_form = doc_form + dataset.enable_api = enable_api + dataset.updated_by = None + dataset.updated_at = None + dataset.summary_index_setting = summary_index_setting + for key, value in kwargs.items(): + setattr(dataset, key, value) + return dataset + + @staticmethod + def create_user_mock( + user_id: str = "user-123", + tenant_id: str = "tenant-123", + role: str = TenantAccountRole.OWNER, + **kwargs, + ) -> SimpleNamespace: + user = SimpleNamespace( + id=user_id, + current_tenant_id=tenant_id, + current_role=role, + ) + for key, value in kwargs.items(): + setattr(user, key, value) + return user + + @staticmethod + def create_document_mock( + document_id: str = "doc-123", + dataset_id: str = "dataset-123", + tenant_id: str = "tenant-123", + *, + indexing_status: str = "completed", + is_paused: bool = False, + archived: bool = False, + enabled: bool = True, + data_source_type: str = "upload_file", + data_source_info_dict: dict | None = None, + data_source_info: str | None = None, + doc_form: str = "text_model", + need_summary: bool = True, + position: int = 
0, + doc_metadata: dict | None = None, + name: str = "Document", + **kwargs, + ) -> Mock: + document = Mock(spec=Document) + document.id = document_id + document.dataset_id = dataset_id + document.tenant_id = tenant_id + document.indexing_status = indexing_status + document.is_paused = is_paused + document.paused_by = None + document.paused_at = None + document.archived = archived + document.enabled = enabled + document.data_source_type = data_source_type + document.data_source_info_dict = data_source_info_dict or {} + document.data_source_info = data_source_info + document.doc_form = doc_form + document.need_summary = need_summary + document.position = position + document.doc_metadata = doc_metadata + document.name = name + for key, value in kwargs.items(): + setattr(document, key, value) + return document + + @staticmethod + def create_upload_file_mock(file_id: str = "file-123", name: str = "upload.txt") -> Mock: + upload_file = Mock(spec=UploadFile) + upload_file.id = file_id + upload_file.name = name + return upload_file + + +_UNSET = object() + + +def _make_lock_context() -> MagicMock: + context_manager = MagicMock() + context_manager.__enter__.return_value = None + context_manager.__exit__.return_value = False + return context_manager + + +def _make_features(*, enabled: bool, plan: str = CloudPlan.PROFESSIONAL) -> SimpleNamespace: + return SimpleNamespace( + billing=SimpleNamespace( + enabled=enabled, + subscription=SimpleNamespace(plan=plan), + ), + documents_upload_quota=SimpleNamespace(limit=1000, size=0), + ) + + +def _make_dataset( + *, + dataset_id: str = "dataset-1", + tenant_id: str = "tenant-1", + data_source_type: str | None = None, + indexing_technique: str | None = "economy", + latest_process_rule=None, +) -> Mock: + dataset = Mock(spec=Dataset) + dataset.id = dataset_id + dataset.tenant_id = tenant_id + dataset.data_source_type = data_source_type + dataset.indexing_technique = indexing_technique + dataset.latest_process_rule = latest_process_rule 
+ dataset.embedding_model_provider = "provider" + dataset.embedding_model = "embedding-model" + dataset.summary_index_setting = None + dataset.retrieval_model = None + dataset.collection_binding_id = None + return dataset + + +def _make_document( + *, + document_id: str = "doc-1", + dataset_id: str = "dataset-1", + tenant_id: str = "tenant-1", + batch: str = "batch-1", + doc_form: str = IndexStructureType.PARAGRAPH_INDEX, + word_count: int = 0, + name: str = "Document 1", + enabled: bool = True, + archived: bool = False, + indexing_status: str = "completed", + display_status: str = "available", +) -> Mock: + document = Mock(spec=Document) + document.id = document_id + document.dataset_id = dataset_id + document.tenant_id = tenant_id + document.batch = batch + document.doc_form = doc_form + document.word_count = word_count + document.name = name + document.enabled = enabled + document.archived = archived + document.indexing_status = indexing_status + document.display_status = display_status + document.data_source_type = "upload_file" + document.data_source_info = "{}" + document.completed_at = SimpleNamespace() + document.processing_started_at = "started" + document.parsing_completed_at = "parsed" + document.cleaning_completed_at = "cleaned" + document.splitting_completed_at = "split" + document.updated_at = None + document.created_from = None + document.dataset_process_rule_id = "process-rule-1" + return document + + +def _make_segment( + *, + segment_id: str = "segment-1", + content: str = "segment content", + word_count: int = 15, + enabled: bool = True, + keywords: list[str] | None = None, + index_node_id: str = "node-1", + dataset_id: str = "dataset-1", + document_id: str = "doc-1", +) -> Mock: + segment = Mock(spec=DocumentSegment) + segment.id = segment_id + segment.dataset_id = dataset_id + segment.document_id = document_id + segment.content = content + segment.word_count = word_count + segment.enabled = enabled + segment.keywords = keywords or [] + 
segment.answer = None + segment.index_node_id = index_node_id + segment.disabled_at = None + segment.disabled_by = None + segment.status = "completed" + segment.error = None + return segment + + +def _make_child_chunk() -> ChildChunk: + return ChildChunk( + id="child-a", + tenant_id="tenant-1", + dataset_id="dataset-1", + document_id="doc-1", + segment_id="segment-1", + position=1, + content="old content", + word_count=11, + created_by="user-1", + ) + + +def _make_upload_knowledge_config( + *, + original_document_id: str | None = None, + file_ids: list[str] | None = None, + process_rule: ProcessRule | None = None, + data_source: DataSource | object | None = _UNSET, +) -> KnowledgeConfig: + if data_source is _UNSET: + info_list = InfoList( + data_source_type="upload_file", + file_info_list=FileInfo(file_ids=file_ids) if file_ids is not None else None, + ) + data_source = DataSource(info_list=info_list) + + return KnowledgeConfig( + original_document_id=original_document_id, + indexing_technique="economy", + data_source=data_source, + process_rule=process_rule, + doc_form=IndexStructureType.PARAGRAPH_INDEX, + doc_language="English", + ) + + +def _make_retrieval_model( + *, + reranking_provider_name: str = "rerank-provider", + reranking_model_name: str = "rerank-model", +) -> RetrievalModel: + return RetrievalModel( + search_method=RetrievalMethod.SEMANTIC_SEARCH, + reranking_enable=True, + reranking_model=RerankingModel( + reranking_provider_name=reranking_provider_name, + reranking_model_name=reranking_model_name, + ), + reranking_mode="reranking_model", + top_k=4, + score_threshold_enabled=False, + ) + + +def _make_rag_pipeline_retrieval_setting() -> RagPipelineRetrievalSetting: + return RagPipelineRetrievalSetting( + search_method=RetrievalMethod.SEMANTIC_SEARCH, + top_k=4, + score_threshold=0.5, + score_threshold_enabled=True, + reranking_mode="reranking_model", + reranking_enable=True, + reranking_model=RagPipelineRerankingModelConfig( + 
reranking_provider_name="rerank-provider", + reranking_model_name="rerank-model", + ), + ) + + +def _make_knowledge_configuration( + *, + chunk_structure: str = "paragraph", + indexing_technique: str = "high_quality", + embedding_model_provider: str = "provider", + embedding_model: str = "embedding-model", + keyword_number: int = 8, + summary_index_setting: dict | None = None, +) -> KnowledgeConfiguration: + return KnowledgeConfiguration( + chunk_structure=chunk_structure, + indexing_technique=indexing_technique, + embedding_model_provider=embedding_model_provider, + embedding_model=embedding_model, + keyword_number=keyword_number, + retrieval_model=_make_rag_pipeline_retrieval_setting(), + summary_index_setting=summary_index_setting, + ) diff --git a/api/tests/unit_tests/services/dataset_service_update_delete.py b/api/tests/unit_tests/services/dataset_service_update_delete.py index 424ac18870..62c39f96d3 100644 --- a/api/tests/unit_tests/services/dataset_service_update_delete.py +++ b/api/tests/unit_tests/services/dataset_service_update_delete.py @@ -592,7 +592,7 @@ class TestDatasetServiceUpdateRagPipelineDatasetSettings: patch( "services.dataset_service.current_user", create_autospec(Account, instance=True) ) as mock_current_user, - patch("services.dataset_service.ModelManager") as mock_model_manager, + patch("services.dataset_service.ModelManager.for_tenant") as mock_model_manager, patch( "services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding" ) as mock_get_binding, diff --git a/api/tests/unit_tests/services/document_service_validation.py b/api/tests/unit_tests/services/document_service_validation.py index 49fdc5cc9b..7c36e9d960 100644 --- a/api/tests/unit_tests/services/document_service_validation.py +++ b/api/tests/unit_tests/services/document_service_validation.py @@ -109,10 +109,10 @@ This test suite follows a comprehensive testing strategy that covers: from unittest.mock import Mock, patch import pytest +from 
graphon.model_runtime.entities.model_entities import ModelType from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError from core.rag.index_processor.constant.index_type import IndexStructureType, IndexTechniqueType -from dify_graph.model_runtime.entities.model_entities import ModelType from models.dataset import Dataset, DatasetProcessRule, Document from services.dataset_service import DatasetService, DocumentService from services.entities.knowledge_entities.knowledge_entities import ( @@ -431,7 +431,7 @@ class TestDatasetServiceCheckDatasetModelSetting: Provides a mocked ModelManager that can be used to verify model instance retrieval and error handling. """ - with patch("services.dataset_service.ModelManager") as mock_manager: + with patch("services.dataset_service.ModelManager.for_tenant") as mock_manager: yield mock_manager def test_check_dataset_model_setting_high_quality_success(self, mock_model_manager): @@ -580,7 +580,7 @@ class TestDatasetServiceCheckEmbeddingModelSetting: Provides a mocked ModelManager that can be used to verify model instance retrieval and error handling. """ - with patch("services.dataset_service.ModelManager") as mock_manager: + with patch("services.dataset_service.ModelManager.for_tenant") as mock_manager: yield mock_manager def test_check_embedding_model_setting_success(self, mock_model_manager): @@ -702,7 +702,7 @@ class TestDatasetServiceCheckRerankingModelSetting: Provides a mocked ModelManager that can be used to verify model instance retrieval and error handling. 
""" - with patch("services.dataset_service.ModelManager") as mock_manager: + with patch("services.dataset_service.ModelManager.for_tenant") as mock_manager: yield mock_manager def test_check_reranking_model_setting_success(self, mock_model_manager): diff --git a/api/tests/unit_tests/services/enterprise/test_account_deletion_sync.py b/api/tests/unit_tests/services/enterprise/test_account_deletion_sync.py deleted file mode 100644 index b66111902c..0000000000 --- a/api/tests/unit_tests/services/enterprise/test_account_deletion_sync.py +++ /dev/null @@ -1,276 +0,0 @@ -"""Unit tests for account deletion synchronization. - -This test module verifies the enterprise account deletion sync functionality, -including Redis queuing, error handling, and community vs enterprise behavior. -""" - -from unittest.mock import MagicMock, patch - -import pytest -from redis import RedisError - -from services.enterprise.account_deletion_sync import ( - _queue_task, - sync_account_deletion, - sync_workspace_member_removal, -) - - -class TestQueueTask: - """Unit tests for the _queue_task helper function.""" - - @pytest.fixture - def mock_redis_client(self): - """Mock redis_client for testing.""" - with patch("services.enterprise.account_deletion_sync.redis_client") as mock_redis: - yield mock_redis - - @pytest.fixture - def mock_uuid(self): - """Mock UUID generation for predictable task IDs.""" - with patch("services.enterprise.account_deletion_sync.uuid.uuid4") as mock_uuid_gen: - mock_uuid_gen.return_value = MagicMock(hex="test-task-id-1234") - yield mock_uuid_gen - - def test_queue_task_success(self, mock_redis_client, mock_uuid): - """Test successful task queueing to Redis.""" - # Arrange - workspace_id = "ws-123" - member_id = "member-456" - source = "test_source" - - # Act - result = _queue_task(workspace_id=workspace_id, member_id=member_id, source=source) - - # Assert - assert result is True - mock_redis_client.lpush.assert_called_once() - - # Verify the task payload structure - 
call_args = mock_redis_client.lpush.call_args[0] - assert call_args[0] == "enterprise:member:sync:queue" - - import json - - task_data = json.loads(call_args[1]) - assert task_data["workspace_id"] == workspace_id - assert task_data["member_id"] == member_id - assert task_data["source"] == source - assert task_data["type"] == "sync_member_deletion_from_workspace" - assert task_data["retry_count"] == 0 - assert "task_id" in task_data - assert "created_at" in task_data - - def test_queue_task_redis_error(self, mock_redis_client, caplog): - """Test handling of Redis connection errors.""" - # Arrange - mock_redis_client.lpush.side_effect = RedisError("Connection failed") - - # Act - result = _queue_task(workspace_id="ws-123", member_id="member-456", source="test_source") - - # Assert - assert result is False - assert "Failed to queue account deletion sync" in caplog.text - - def test_queue_task_type_error(self, mock_redis_client, caplog): - """Test handling of JSON serialization errors.""" - # Arrange - mock_redis_client.lpush.side_effect = TypeError("Cannot serialize") - - # Act - result = _queue_task(workspace_id="ws-123", member_id="member-456", source="test_source") - - # Assert - assert result is False - assert "Failed to queue account deletion sync" in caplog.text - - -class TestSyncWorkspaceMemberRemoval: - """Unit tests for sync_workspace_member_removal function.""" - - @pytest.fixture - def mock_queue_task(self): - """Mock _queue_task for testing.""" - with patch("services.enterprise.account_deletion_sync._queue_task") as mock_queue: - mock_queue.return_value = True - yield mock_queue - - def test_sync_workspace_member_removal_enterprise_enabled(self, mock_queue_task): - """Test sync when ENTERPRISE_ENABLED is True.""" - # Arrange - workspace_id = "ws-123" - member_id = "member-456" - source = "workspace_member_removed" - - with patch("services.enterprise.account_deletion_sync.dify_config") as mock_config: - mock_config.ENTERPRISE_ENABLED = True - - # Act - 
result = sync_workspace_member_removal(workspace_id=workspace_id, member_id=member_id, source=source) - - # Assert - assert result is True - mock_queue_task.assert_called_once_with(workspace_id=workspace_id, member_id=member_id, source=source) - - def test_sync_workspace_member_removal_enterprise_disabled(self, mock_queue_task): - """Test sync when ENTERPRISE_ENABLED is False (community edition).""" - # Arrange - with patch("services.enterprise.account_deletion_sync.dify_config") as mock_config: - mock_config.ENTERPRISE_ENABLED = False - - # Act - result = sync_workspace_member_removal(workspace_id="ws-123", member_id="member-456", source="test_source") - - # Assert - assert result is True - mock_queue_task.assert_not_called() - - def test_sync_workspace_member_removal_queue_failure(self, mock_queue_task): - """Test handling of queue task failures.""" - # Arrange - mock_queue_task.return_value = False - - with patch("services.enterprise.account_deletion_sync.dify_config") as mock_config: - mock_config.ENTERPRISE_ENABLED = True - - # Act - result = sync_workspace_member_removal(workspace_id="ws-123", member_id="member-456", source="test_source") - - # Assert - assert result is False - - -class TestSyncAccountDeletion: - """Unit tests for sync_account_deletion function.""" - - @pytest.fixture - def mock_db_session(self): - """Mock database session for testing.""" - with patch("services.enterprise.account_deletion_sync.db.session") as mock_session: - yield mock_session - - @pytest.fixture - def mock_queue_task(self): - """Mock _queue_task for testing.""" - with patch("services.enterprise.account_deletion_sync._queue_task") as mock_queue: - mock_queue.return_value = True - yield mock_queue - - def test_sync_account_deletion_enterprise_disabled(self, mock_db_session, mock_queue_task): - """Test sync when ENTERPRISE_ENABLED is False (community edition).""" - # Arrange - with patch("services.enterprise.account_deletion_sync.dify_config") as mock_config: - 
mock_config.ENTERPRISE_ENABLED = False - - # Act - result = sync_account_deletion(account_id="acc-123", source="account_deleted") - - # Assert - assert result is True - mock_db_session.query.assert_not_called() - mock_queue_task.assert_not_called() - - def test_sync_account_deletion_multiple_workspaces(self, mock_db_session, mock_queue_task): - """Test sync for account with multiple workspace memberships.""" - # Arrange - account_id = "acc-123" - - # Mock workspace joins - mock_join1 = MagicMock() - mock_join1.tenant_id = "tenant-1" - mock_join2 = MagicMock() - mock_join2.tenant_id = "tenant-2" - mock_join3 = MagicMock() - mock_join3.tenant_id = "tenant-3" - - mock_query = MagicMock() - mock_query.filter_by.return_value.all.return_value = [mock_join1, mock_join2, mock_join3] - mock_db_session.query.return_value = mock_query - - with patch("services.enterprise.account_deletion_sync.dify_config") as mock_config: - mock_config.ENTERPRISE_ENABLED = True - - # Act - result = sync_account_deletion(account_id=account_id, source="account_deleted") - - # Assert - assert result is True - assert mock_queue_task.call_count == 3 - - # Verify each workspace was queued - mock_queue_task.assert_any_call(workspace_id="tenant-1", member_id=account_id, source="account_deleted") - mock_queue_task.assert_any_call(workspace_id="tenant-2", member_id=account_id, source="account_deleted") - mock_queue_task.assert_any_call(workspace_id="tenant-3", member_id=account_id, source="account_deleted") - - def test_sync_account_deletion_no_workspaces(self, mock_db_session, mock_queue_task): - """Test sync for account with no workspace memberships.""" - # Arrange - mock_query = MagicMock() - mock_query.filter_by.return_value.all.return_value = [] - mock_db_session.query.return_value = mock_query - - with patch("services.enterprise.account_deletion_sync.dify_config") as mock_config: - mock_config.ENTERPRISE_ENABLED = True - - # Act - result = sync_account_deletion(account_id="acc-123", 
source="account_deleted") - - # Assert - assert result is True - mock_queue_task.assert_not_called() - - def test_sync_account_deletion_partial_failure(self, mock_db_session, mock_queue_task): - """Test sync when some tasks fail to queue.""" - # Arrange - account_id = "acc-123" - - # Mock workspace joins - mock_join1 = MagicMock() - mock_join1.tenant_id = "tenant-1" - mock_join2 = MagicMock() - mock_join2.tenant_id = "tenant-2" - mock_join3 = MagicMock() - mock_join3.tenant_id = "tenant-3" - - mock_query = MagicMock() - mock_query.filter_by.return_value.all.return_value = [mock_join1, mock_join2, mock_join3] - mock_db_session.query.return_value = mock_query - - # Mock queue_task to fail for second workspace - def queue_side_effect(workspace_id, member_id, source): - return workspace_id != "tenant-2" - - mock_queue_task.side_effect = queue_side_effect - - with patch("services.enterprise.account_deletion_sync.dify_config") as mock_config: - mock_config.ENTERPRISE_ENABLED = True - - # Act - result = sync_account_deletion(account_id=account_id, source="account_deleted") - - # Assert - assert result is False # Should return False if any task fails - assert mock_queue_task.call_count == 3 - - def test_sync_account_deletion_all_failures(self, mock_db_session, mock_queue_task): - """Test sync when all tasks fail to queue.""" - # Arrange - mock_join = MagicMock() - mock_join.tenant_id = "tenant-1" - - mock_query = MagicMock() - mock_query.filter_by.return_value.all.return_value = [mock_join] - mock_db_session.query.return_value = mock_query - - mock_queue_task.return_value = False - - with patch("services.enterprise.account_deletion_sync.dify_config") as mock_config: - mock_config.ENTERPRISE_ENABLED = True - - # Act - result = sync_account_deletion(account_id="acc-123", source="account_deleted") - - # Assert - assert result is False - mock_queue_task.assert_called_once() diff --git a/api/tests/unit_tests/services/external_dataset_service.py 
b/api/tests/unit_tests/services/external_dataset_service.py index afc3b29fca..a8ef35a0d0 100644 --- a/api/tests/unit_tests/services/external_dataset_service.py +++ b/api/tests/unit_tests/services/external_dataset_service.py @@ -545,7 +545,7 @@ class TestExternalDatasetServiceProcessExternalApi: params={}, ) - from dify_graph.nodes.http_request.exc import InvalidHttpMethodError + from graphon.nodes.http_request.exc import InvalidHttpMethodError with pytest.raises(InvalidHttpMethodError): ExternalDatasetService.process_external_api(settings, files=None) diff --git a/api/tests/unit_tests/services/recommend_app/test_database_retrieval.py b/api/tests/unit_tests/services/recommend_app/test_database_retrieval.py deleted file mode 100644 index 5d21665f75..0000000000 --- a/api/tests/unit_tests/services/recommend_app/test_database_retrieval.py +++ /dev/null @@ -1,145 +0,0 @@ -from types import SimpleNamespace -from unittest.mock import MagicMock, patch - -from services.recommend_app.database.database_retrieval import DatabaseRecommendAppRetrieval -from services.recommend_app.recommend_app_type import RecommendAppType - - -class TestDatabaseRecommendAppRetrieval: - def test_get_type(self): - assert DatabaseRecommendAppRetrieval().get_type() == RecommendAppType.DATABASE - - def test_get_recommended_apps_delegates(self): - with patch.object( - DatabaseRecommendAppRetrieval, - "fetch_recommended_apps_from_db", - return_value={"recommended_apps": [], "categories": []}, - ) as mock_fetch: - result = DatabaseRecommendAppRetrieval().get_recommended_apps_and_categories("en-US") - mock_fetch.assert_called_once_with("en-US") - assert result == {"recommended_apps": [], "categories": []} - - def test_get_recommend_app_detail_delegates(self): - with patch.object( - DatabaseRecommendAppRetrieval, - "fetch_recommended_app_detail_from_db", - return_value={"id": "app-1"}, - ) as mock_fetch: - result = DatabaseRecommendAppRetrieval().get_recommend_app_detail("app-1") - 
mock_fetch.assert_called_once_with("app-1") - assert result == {"id": "app-1"} - - -class TestFetchRecommendedAppsFromDb: - def _make_recommended_app(self, app_id, category, is_public=True, has_site=True): - site = ( - SimpleNamespace( - description="desc", - copyright="copy", - privacy_policy="pp", - custom_disclaimer="cd", - ) - if has_site - else None - ) - app = ( - SimpleNamespace(is_public=is_public, site=site) - if is_public - else SimpleNamespace(is_public=False, site=site) - ) - return SimpleNamespace( - id=f"rec-{app_id}", - app=app, - app_id=app_id, - category=category, - position=1, - is_listed=True, - ) - - @patch("services.recommend_app.database.database_retrieval.db") - def test_returns_apps_and_sorted_categories(self, mock_db): - rec1 = self._make_recommended_app("a1", "writing") - rec2 = self._make_recommended_app("a2", "assistant") - mock_db.session.scalars.return_value.all.return_value = [rec1, rec2] - - result = DatabaseRecommendAppRetrieval.fetch_recommended_apps_from_db("en-US") - - assert len(result["recommended_apps"]) == 2 - assert result["categories"] == ["assistant", "writing"] - - @patch("services.recommend_app.database.database_retrieval.db") - def test_falls_back_to_default_language_when_empty(self, mock_db): - mock_db.session.scalars.return_value.all.side_effect = [ - [], - [self._make_recommended_app("a1", "chat")], - ] - - result = DatabaseRecommendAppRetrieval.fetch_recommended_apps_from_db("fr-FR") - - assert len(result["recommended_apps"]) == 1 - assert mock_db.session.scalars.call_count == 2 - - @patch("services.recommend_app.database.database_retrieval.db") - def test_skips_non_public_apps(self, mock_db): - rec = self._make_recommended_app("a1", "chat", is_public=False) - mock_db.session.scalars.return_value.all.return_value = [rec] - - result = DatabaseRecommendAppRetrieval.fetch_recommended_apps_from_db("en-US") - - assert result["recommended_apps"] == [] - - @patch("services.recommend_app.database.database_retrieval.db") - 
def test_skips_apps_without_site(self, mock_db): - rec = self._make_recommended_app("a1", "chat", has_site=False) - mock_db.session.scalars.return_value.all.return_value = [rec] - - result = DatabaseRecommendAppRetrieval.fetch_recommended_apps_from_db("en-US") - - assert result["recommended_apps"] == [] - - -class TestFetchRecommendedAppDetailFromDb: - @patch("services.recommend_app.database.database_retrieval.db") - def test_returns_none_when_not_listed(self, mock_db): - mock_db.session.query.return_value.where.return_value.first.return_value = None - - result = DatabaseRecommendAppRetrieval.fetch_recommended_app_detail_from_db("app-1") - - assert result is None - - @patch("services.recommend_app.database.database_retrieval.AppDslService") - @patch("services.recommend_app.database.database_retrieval.db") - def test_returns_none_when_app_not_public(self, mock_db, mock_dsl): - rec_chain = MagicMock() - rec_chain.where.return_value.first.return_value = SimpleNamespace(app_id="app-1") - app_chain = MagicMock() - app_chain.where.return_value.first.return_value = SimpleNamespace(id="app-1", is_public=False) - mock_db.session.query.side_effect = [rec_chain, app_chain] - - result = DatabaseRecommendAppRetrieval.fetch_recommended_app_detail_from_db("app-1") - - assert result is None - - @patch("services.recommend_app.database.database_retrieval.AppDslService") - @patch("services.recommend_app.database.database_retrieval.db") - def test_returns_detail_on_success(self, mock_db, mock_dsl): - app_model = SimpleNamespace( - id="app-1", - name="My App", - icon="icon.png", - icon_background="#fff", - mode="chat", - is_public=True, - ) - rec_chain = MagicMock() - rec_chain.where.return_value.first.return_value = SimpleNamespace(app_id="app-1") - app_chain = MagicMock() - app_chain.where.return_value.first.return_value = app_model - mock_db.session.query.side_effect = [rec_chain, app_chain] - mock_dsl.export_dsl.return_value = "exported_yaml" - - result = 
DatabaseRecommendAppRetrieval.fetch_recommended_app_detail_from_db("app-1") - - assert result["id"] == "app-1" - assert result["name"] == "My App" - assert result["export_data"] == "exported_yaml" diff --git a/api/tests/unit_tests/services/retention/conversation/test_messages_clean_service.py b/api/tests/unit_tests/services/retention/conversation/test_messages_clean_service.py deleted file mode 100644 index f9d901fca2..0000000000 --- a/api/tests/unit_tests/services/retention/conversation/test_messages_clean_service.py +++ /dev/null @@ -1,311 +0,0 @@ -import datetime -from unittest.mock import MagicMock, patch - -import pytest - -from services.retention.conversation.messages_clean_policy import ( - BillingDisabledPolicy, -) -from services.retention.conversation.messages_clean_service import MessagesCleanService - - -class TestMessagesCleanService: - @pytest.fixture(autouse=True) - def mock_db_engine(self): - with patch("services.retention.conversation.messages_clean_service.db") as mock_db: - mock_db.engine = MagicMock() - yield mock_db.engine - - @pytest.fixture - def mock_db_session(self, mock_db_engine): - with patch("services.retention.conversation.messages_clean_service.Session") as mock_session_cls: - mock_session = MagicMock() - mock_session_cls.return_value.__enter__.return_value = mock_session - yield mock_session - - @pytest.fixture - def mock_policy(self): - policy = MagicMock(spec=BillingDisabledPolicy) - return policy - - def test_run_calls_clean_messages(self, mock_policy): - service = MessagesCleanService( - policy=mock_policy, - end_before=datetime.datetime.now(), - batch_size=10, - ) - with patch.object(service, "_clean_messages_by_time_range") as mock_clean: - mock_clean.return_value = {"total_deleted": 5} - result = service.run() - assert result == {"total_deleted": 5} - mock_clean.assert_called_once() - - def test_clean_messages_by_time_range_basic(self, mock_db_session, mock_policy): - # Arrange - end_before = datetime.datetime(2024, 1, 1, 12, 
0, 0) - service = MessagesCleanService( - policy=mock_policy, - end_before=end_before, - batch_size=10, - ) - - mock_db_session.execute.side_effect = [ - MagicMock(all=lambda: [("msg1", "app1", datetime.datetime(2024, 1, 1, 10, 0, 0))]), # messages - MagicMock(all=lambda: [MagicMock(id="app1", tenant_id="tenant1")]), # apps - MagicMock( - rowcount=1 - ), # delete relations (this is wrong, relations delete doesn't use rowcount here, but execute) - MagicMock(rowcount=1), # delete relations - MagicMock(rowcount=1), # delete relations - MagicMock(rowcount=1), # delete relations - MagicMock(rowcount=1), # delete relations - MagicMock(rowcount=1), # delete relations - MagicMock(rowcount=1), # delete relations - MagicMock(rowcount=1), # delete relations - MagicMock(rowcount=1), # delete messages - MagicMock(all=lambda: []), # next batch empty - ] - - # Reset side_effect to be more robust - # The service calls session.execute for: - # 1. Fetch messages - # 2. Fetch apps - # 3. Batch delete relations (8 calls if IDs exist) - # 4. 
Delete messages - - mock_returns = [ - MagicMock(all=lambda: [("msg1", "app1", datetime.datetime(2024, 1, 1, 10, 0, 0))]), # fetch messages - MagicMock(all=lambda: [MagicMock(id="app1", tenant_id="tenant1")]), # fetch apps - ] - # 8 deletes for relations - mock_returns.extend([MagicMock() for _ in range(8)]) - # 1 delete for messages - mock_returns.append(MagicMock(rowcount=1)) - # Final fetch messages (empty) - mock_returns.append(MagicMock(all=lambda: [])) - - mock_db_session.execute.side_effect = mock_returns - mock_policy.filter_message_ids.return_value = ["msg1"] - - # Act - with patch("services.retention.conversation.messages_clean_service.time.sleep"): - stats = service.run() - - # Assert - assert stats["total_messages"] == 1 - assert stats["total_deleted"] == 1 - assert stats["batches"] == 2 - - def test_clean_messages_by_time_range_with_start_from(self, mock_db_session, mock_policy): - start_from = datetime.datetime(2024, 1, 1, 0, 0, 0) - end_before = datetime.datetime(2024, 1, 1, 12, 0, 0) - service = MessagesCleanService( - policy=mock_policy, - start_from=start_from, - end_before=end_before, - batch_size=10, - ) - - mock_db_session.execute.side_effect = [ - MagicMock(all=lambda: []), # No messages - ] - - stats = service.run() - assert stats["total_messages"] == 0 - - def test_clean_messages_by_time_range_with_cursor(self, mock_db_session, mock_policy): - # Test pagination with cursor - end_before = datetime.datetime(2024, 1, 1, 12, 0, 0) - service = MessagesCleanService( - policy=mock_policy, - end_before=end_before, - batch_size=1, - ) - - msg1_time = datetime.datetime(2024, 1, 1, 10, 0, 0) - msg2_time = datetime.datetime(2024, 1, 1, 11, 0, 0) - - mock_returns = [] - # Batch 1 - mock_returns.append(MagicMock(all=lambda: [("msg1", "app1", msg1_time)])) - mock_returns.append(MagicMock(all=lambda: [MagicMock(id="app1", tenant_id="tenant1")])) - mock_returns.extend([MagicMock() for _ in range(8)]) # relations - mock_returns.append(MagicMock(rowcount=1)) # 
messages - - # Batch 2 - mock_returns.append(MagicMock(all=lambda: [("msg2", "app1", msg2_time)])) - mock_returns.append(MagicMock(all=lambda: [MagicMock(id="app1", tenant_id="tenant1")])) - mock_returns.extend([MagicMock() for _ in range(8)]) # relations - mock_returns.append(MagicMock(rowcount=1)) # messages - - # Batch 3 - mock_returns.append(MagicMock(all=lambda: [])) - - mock_db_session.execute.side_effect = mock_returns - mock_policy.filter_message_ids.return_value = ["msg1"] # Simplified - - with patch("services.retention.conversation.messages_clean_service.time.sleep"): - stats = service.run() - - assert stats["batches"] == 3 - assert stats["total_messages"] == 2 - - def test_clean_messages_by_time_range_dry_run(self, mock_db_session, mock_policy): - service = MessagesCleanService( - policy=mock_policy, - end_before=datetime.datetime.now(), - batch_size=10, - dry_run=True, - ) - - mock_db_session.execute.side_effect = [ - MagicMock(all=lambda: [("msg1", "app1", datetime.datetime.now())]), # messages - MagicMock(all=lambda: [MagicMock(id="app1", tenant_id="tenant1")]), # apps - MagicMock(all=lambda: []), # next batch empty - ] - mock_policy.filter_message_ids.return_value = ["msg1"] - - with patch("services.retention.conversation.messages_clean_service.random.sample") as mock_sample: - mock_sample.return_value = ["msg1"] - stats = service.run() - assert stats["filtered_messages"] == 1 - assert stats["total_deleted"] == 0 # Dry run - mock_sample.assert_called() - - def test_clean_messages_by_time_range_no_apps_found(self, mock_db_session, mock_policy): - service = MessagesCleanService( - policy=mock_policy, - end_before=datetime.datetime.now(), - batch_size=10, - ) - - mock_db_session.execute.side_effect = [ - MagicMock(all=lambda: [("msg1", "app1", datetime.datetime.now())]), # messages - MagicMock(all=lambda: []), # apps NOT found - MagicMock(all=lambda: []), # next batch empty - ] - - stats = service.run() - assert stats["total_messages"] == 1 - assert 
stats["total_deleted"] == 0 - - def test_clean_messages_by_time_range_no_app_ids(self, mock_db_session, mock_policy): - service = MessagesCleanService( - policy=mock_policy, - end_before=datetime.datetime.now(), - batch_size=10, - ) - - mock_db_session.execute.side_effect = [ - MagicMock(all=lambda: [("msg1", "app1", datetime.datetime.now())]), # messages - MagicMock(all=lambda: []), # next batch empty - ] - - # We need to successfully execute line 228 and 229, then return empty at 251. - # line 228: raw_messages = list(session.execute(msg_stmt).all()) - # line 251: app_ids = list({msg.app_id for msg in messages}) - - calls = [] - - def list_side_effect(arg): - calls.append(arg) - if len(calls) == 2: # This is the second call to list() in the loop - return [] - return list(arg) - - with patch("services.retention.conversation.messages_clean_service.list", side_effect=list_side_effect): - stats = service.run() - assert stats["batches"] == 2 - assert stats["total_messages"] == 1 - - def test_from_time_range_validation(self, mock_policy): - now = datetime.datetime.now() - # Test start_from >= end_before - with pytest.raises(ValueError, match="start_from .* must be less than end_before"): - MessagesCleanService.from_time_range(mock_policy, now, now) - - # Test batch_size <= 0 - with pytest.raises(ValueError, match="batch_size .* must be greater than 0"): - MessagesCleanService.from_time_range(mock_policy, now - datetime.timedelta(days=1), now, batch_size=0) - - def test_from_time_range_success(self, mock_policy): - start = datetime.datetime(2024, 1, 1) - end = datetime.datetime(2024, 2, 1) - # Mock logger to avoid actual logging if needed, though it's fine - service = MessagesCleanService.from_time_range(mock_policy, start, end) - assert service._start_from == start - assert service._end_before == end - - def test_from_days_validation(self, mock_policy): - # Test days < 0 - with pytest.raises(ValueError, match="days .* must be greater than or equal to 0"): - 
MessagesCleanService.from_days(mock_policy, days=-1) - - # Test batch_size <= 0 - with pytest.raises(ValueError, match="batch_size .* must be greater than 0"): - MessagesCleanService.from_days(mock_policy, days=30, batch_size=0) - - def test_from_days_success(self, mock_policy): - with patch("services.retention.conversation.messages_clean_service.naive_utc_now") as mock_now: - fixed_now = datetime.datetime(2024, 6, 1) - mock_now.return_value = fixed_now - - service = MessagesCleanService.from_days(mock_policy, days=10) - assert service._start_from is None - assert service._end_before == fixed_now - datetime.timedelta(days=10) - - def test_clean_messages_by_time_range_no_messages_to_delete(self, mock_db_session, mock_policy): - service = MessagesCleanService( - policy=mock_policy, - end_before=datetime.datetime.now(), - batch_size=10, - ) - - mock_db_session.execute.side_effect = [ - MagicMock(all=lambda: [("msg1", "app1", datetime.datetime.now())]), # messages - MagicMock(all=lambda: [MagicMock(id="app1", tenant_id="tenant1")]), # apps - MagicMock(all=lambda: []), # next batch empty - ] - mock_policy.filter_message_ids.return_value = [] # Policy says NO - - stats = service.run() - assert stats["total_messages"] == 1 - assert stats["filtered_messages"] == 0 - assert stats["total_deleted"] == 0 - - def test_batch_delete_message_relations_empty(self, mock_db_session): - MessagesCleanService._batch_delete_message_relations(mock_db_session, []) - mock_db_session.execute.assert_not_called() - - def test_batch_delete_message_relations_with_ids(self, mock_db_session): - MessagesCleanService._batch_delete_message_relations(mock_db_session, ["msg1", "msg2"]) - assert mock_db_session.execute.call_count == 8 # 8 tables to clean up - - def test_clean_messages_interval_from_env(self, mock_db_session, mock_policy): - service = MessagesCleanService( - policy=mock_policy, - end_before=datetime.datetime.now(), - batch_size=10, - ) - - mock_returns = [ - MagicMock(all=lambda: 
[("msg1", "app1", datetime.datetime.now())]), # messages - MagicMock(all=lambda: [MagicMock(id="app1", tenant_id="tenant1")]), # apps - ] - mock_returns.extend([MagicMock() for _ in range(8)]) # relations - mock_returns.append(MagicMock(rowcount=1)) # messages - mock_returns.append(MagicMock(all=lambda: [])) # next batch empty - - mock_db_session.execute.side_effect = mock_returns - mock_policy.filter_message_ids.return_value = ["msg1"] - - with patch( - "services.retention.conversation.messages_clean_service.dify_config.SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_MAX_INTERVAL", - 500, - ): - with patch("services.retention.conversation.messages_clean_service.time.sleep") as mock_sleep: - with patch("services.retention.conversation.messages_clean_service.random.uniform") as mock_uniform: - mock_uniform.return_value = 300.0 - service.run() - mock_uniform.assert_called_with(0, 500) - mock_sleep.assert_called_with(0.3) diff --git a/api/tests/unit_tests/services/segment_service.py b/api/tests/unit_tests/services/segment_service.py index 14af7f7119..f0a66a00d4 100644 --- a/api/tests/unit_tests/services/segment_service.py +++ b/api/tests/unit_tests/services/segment_service.py @@ -267,7 +267,7 @@ class TestSegmentServiceCreateSegment: patch( "services.dataset_service.VectorService.create_segments_vector", autospec=True ) as mock_vector_service, - patch("services.dataset_service.ModelManager", autospec=True) as mock_model_manager_class, + patch("services.dataset_service.ModelManager.for_tenant", autospec=True) as mock_model_manager_class, patch("services.dataset_service.helper.generate_text_hash", autospec=True) as mock_hash, patch("services.dataset_service.naive_utc_now", autospec=True) as mock_now, ): diff --git a/api/tests/unit_tests/services/test_app_dsl_service.py b/api/tests/unit_tests/services/test_app_dsl_service.py index 239e51119c..179518a5fa 100644 --- a/api/tests/unit_tests/services/test_app_dsl_service.py +++ b/api/tests/unit_tests/services/test_app_dsl_service.py @@ 
-4,13 +4,13 @@ from unittest.mock import MagicMock import pytest import yaml +from graphon.enums import BuiltinNodeTypes from core.trigger.constants import ( TRIGGER_PLUGIN_NODE_TYPE, TRIGGER_SCHEDULE_NODE_TYPE, TRIGGER_WEBHOOK_NODE_TYPE, ) -from dify_graph.enums import BuiltinNodeTypes from models import Account, AppMode from models.model import IconType from services import app_dsl_service diff --git a/api/tests/unit_tests/services/test_audio_service.py b/api/tests/unit_tests/services/test_audio_service.py index 35b288cf7c..175fd3ee01 100644 --- a/api/tests/unit_tests/services/test_audio_service.py +++ b/api/tests/unit_tests/services/test_audio_service.py @@ -214,7 +214,7 @@ def factory(): class TestAudioServiceASR: """Test speech-to-text (ASR) operations.""" - @patch("services.audio_service.ModelManager", autospec=True) + @patch("services.audio_service.ModelManager.for_tenant", autospec=True) def test_transcript_asr_success_chat_mode(self, mock_model_manager_class, factory): """Test successful ASR transcription in CHAT mode.""" # Arrange @@ -237,10 +237,9 @@ class TestAudioServiceASR: # Assert assert result == {"text": "Transcribed text"} mock_model_instance.invoke_speech2text.assert_called_once() - call_args = mock_model_instance.invoke_speech2text.call_args - assert call_args.kwargs["user"] == "user-123" + mock_model_manager_class.assert_called_once_with(tenant_id=app.tenant_id, user_id="user-123") - @patch("services.audio_service.ModelManager", autospec=True) + @patch("services.audio_service.ModelManager.for_tenant", autospec=True) def test_transcript_asr_success_advanced_chat_mode(self, mock_model_manager_class, factory): """Test successful ASR transcription in ADVANCED_CHAT mode.""" # Arrange @@ -347,7 +346,7 @@ class TestAudioServiceASR: with pytest.raises(AudioTooLargeServiceError, match="Audio size larger than 30 mb"): AudioService.transcript_asr(app_model=app, file=file) - @patch("services.audio_service.ModelManager", autospec=True) + 
@patch("services.audio_service.ModelManager.for_tenant", autospec=True) def test_transcript_asr_raises_error_when_no_model_instance(self, mock_model_manager_class, factory): """Test that ASR raises error when no model instance is available.""" # Arrange @@ -370,7 +369,7 @@ class TestAudioServiceASR: class TestAudioServiceTTS: """Test text-to-speech (TTS) operations.""" - @patch("services.audio_service.ModelManager", autospec=True) + @patch("services.audio_service.ModelManager.for_tenant", autospec=True) def test_transcript_tts_with_text_success(self, mock_model_manager_class, factory): """Test successful TTS with text input.""" # Arrange @@ -398,15 +397,14 @@ class TestAudioServiceTTS: # Assert assert result == b"audio data" + mock_model_manager_class.assert_called_once_with(tenant_id=app.tenant_id, user_id="user-123") mock_model_instance.invoke_tts.assert_called_once_with( content_text="Hello world", - user="user-123", - tenant_id=app.tenant_id, voice="en-US-Neural", ) - @patch("services.audio_service.db.session") - @patch("services.audio_service.ModelManager", autospec=True) + @patch("services.audio_service.db.session", autospec=True) + @patch("services.audio_service.ModelManager.for_tenant", autospec=True) def test_transcript_tts_with_message_id_success(self, mock_model_manager_class, mock_db_session, factory): """Test successful TTS with message ID.""" # Arrange @@ -445,7 +443,7 @@ class TestAudioServiceTTS: assert result == b"audio from message" mock_model_instance.invoke_tts.assert_called_once() - @patch("services.audio_service.ModelManager", autospec=True) + @patch("services.audio_service.ModelManager.for_tenant", autospec=True) def test_transcript_tts_with_default_voice(self, mock_model_manager_class, factory): """Test TTS uses default voice when none specified.""" # Arrange @@ -475,7 +473,7 @@ class TestAudioServiceTTS: call_args = mock_model_instance.invoke_tts.call_args assert call_args.kwargs["voice"] == "default-voice" - 
@patch("services.audio_service.ModelManager", autospec=True) + @patch("services.audio_service.ModelManager.for_tenant", autospec=True) def test_transcript_tts_gets_first_available_voice_when_none_configured(self, mock_model_manager_class, factory): """Test TTS gets first available voice when none is configured.""" # Arrange @@ -506,7 +504,7 @@ class TestAudioServiceTTS: assert call_args.kwargs["voice"] == "auto-voice" @patch("services.audio_service.WorkflowService", autospec=True) - @patch("services.audio_service.ModelManager", autospec=True) + @patch("services.audio_service.ModelManager.for_tenant", autospec=True) def test_transcript_tts_workflow_mode_with_draft( self, mock_model_manager_class, mock_workflow_service_class, factory ): @@ -611,7 +609,7 @@ class TestAudioServiceTTS: # Assert assert result is None - @patch("services.audio_service.ModelManager", autospec=True) + @patch("services.audio_service.ModelManager.for_tenant", autospec=True) def test_transcript_tts_raises_error_when_no_voices_available(self, mock_model_manager_class, factory): """Test that TTS raises error when no voices are available.""" # Arrange @@ -637,7 +635,7 @@ class TestAudioServiceTTS: class TestAudioServiceTTSVoices: """Test TTS voice listing operations.""" - @patch("services.audio_service.ModelManager", autospec=True) + @patch("services.audio_service.ModelManager.for_tenant", autospec=True) def test_transcript_tts_voices_success(self, mock_model_manager_class, factory): """Test successful retrieval of TTS voices.""" # Arrange @@ -662,7 +660,7 @@ class TestAudioServiceTTSVoices: assert result == expected_voices mock_model_instance.get_tts_voices.assert_called_once_with(language) - @patch("services.audio_service.ModelManager", autospec=True) + @patch("services.audio_service.ModelManager.for_tenant", autospec=True) def test_transcript_tts_voices_raises_error_when_no_model_instance(self, mock_model_manager_class, factory): """Test that TTS voices raises error when no model instance is 
available.""" # Arrange @@ -677,7 +675,7 @@ class TestAudioServiceTTSVoices: with pytest.raises(ProviderNotSupportTextToSpeechServiceError): AudioService.transcript_tts_voices(tenant_id=tenant_id, language=language) - @patch("services.audio_service.ModelManager", autospec=True) + @patch("services.audio_service.ModelManager.for_tenant", autospec=True) def test_transcript_tts_voices_propagates_exceptions(self, mock_model_manager_class, factory): """Test that TTS voices propagates exceptions from model instance.""" # Arrange diff --git a/api/tests/unit_tests/services/test_conversation_service.py b/api/tests/unit_tests/services/test_conversation_service.py index 35157790ca..1bf4c0e172 100644 --- a/api/tests/unit_tests/services/test_conversation_service.py +++ b/api/tests/unit_tests/services/test_conversation_service.py @@ -6,13 +6,14 @@ Tests are organized by functionality and include edge cases, error handling, and both positive and negative test scenarios. """ -from datetime import datetime, timedelta +from datetime import timedelta from unittest.mock import MagicMock, Mock, create_autospec, patch import pytest from sqlalchemy import asc, desc from core.app.entities.app_invoke_entities import InvokeFrom +from libs.datetime_utils import naive_utc_now from libs.infinite_scroll_pagination import InfiniteScrollPagination from models import Account, ConversationVariable from models.enums import ConversationFromSource @@ -122,8 +123,8 @@ class ConversationServiceTestDataFactory: conversation.is_deleted = kwargs.get("is_deleted", False) conversation.name = kwargs.get("name", "Test Conversation") conversation.status = kwargs.get("status", "normal") - conversation.created_at = kwargs.get("created_at", datetime.utcnow()) - conversation.updated_at = kwargs.get("updated_at", datetime.utcnow()) + conversation.created_at = kwargs.get("created_at", naive_utc_now()) + conversation.updated_at = kwargs.get("updated_at", naive_utc_now()) for key, value in kwargs.items(): 
setattr(conversation, key, value) return conversation @@ -152,7 +153,7 @@ class ConversationServiceTestDataFactory: message.conversation_id = conversation_id message.app_id = app_id message.query = kwargs.get("query", "Test message content") - message.created_at = kwargs.get("created_at", datetime.utcnow()) + message.created_at = kwargs.get("created_at", naive_utc_now()) for key, value in kwargs.items(): setattr(message, key, value) return message @@ -181,8 +182,8 @@ class ConversationServiceTestDataFactory: variable.conversation_id = conversation_id variable.app_id = app_id variable.data = {"name": kwargs.get("name", "test_var"), "value": kwargs.get("value", "test_value")} - variable.created_at = kwargs.get("created_at", datetime.utcnow()) - variable.updated_at = kwargs.get("updated_at", datetime.utcnow()) + variable.created_at = kwargs.get("created_at", naive_utc_now()) + variable.updated_at = kwargs.get("updated_at", naive_utc_now()) # Mock to_variable method mock_variable = Mock() @@ -302,7 +303,7 @@ class TestConversationServiceHelpers: """ # Arrange mock_conversation = ConversationServiceTestDataFactory.create_conversation_mock() - mock_conversation.updated_at = datetime.utcnow() + mock_conversation.updated_at = naive_utc_now() # Act condition = ConversationService._build_filter_condition( @@ -323,7 +324,7 @@ class TestConversationServiceHelpers: """ # Arrange mock_conversation = ConversationServiceTestDataFactory.create_conversation_mock() - mock_conversation.created_at = datetime.utcnow() + mock_conversation.created_at = naive_utc_now() # Act condition = ConversationService._build_filter_condition( @@ -668,9 +669,9 @@ class TestConversationServiceConversationalVariable: mock_session_factory.create_session.return_value.__enter__.return_value = mock_session last_variable = ConversationServiceTestDataFactory.create_conversation_variable_mock( - created_at=datetime.utcnow() - timedelta(hours=1) + created_at=naive_utc_now() - timedelta(hours=1) ) - variable = 
ConversationServiceTestDataFactory.create_conversation_variable_mock(created_at=datetime.utcnow()) + variable = ConversationServiceTestDataFactory.create_conversation_variable_mock(created_at=naive_utc_now()) mock_session.scalar.return_value = last_variable mock_session.scalars.return_value.all.return_value = [variable] diff --git a/api/tests/unit_tests/services/test_dataset_service_dataset.py b/api/tests/unit_tests/services/test_dataset_service_dataset.py new file mode 100644 index 0000000000..92aed7c30a --- /dev/null +++ b/api/tests/unit_tests/services/test_dataset_service_dataset.py @@ -0,0 +1,1760 @@ +"""Unit tests for DatasetService and dataset-related collaborators.""" + +from .dataset_service_test_helpers import ( + CloudPlan, + Dataset, + DatasetCollectionBindingService, + DatasetNameDuplicateError, + DatasetPermissionEnum, + DatasetPermissionService, + DatasetProcessRule, + DatasetService, + DatasetServiceUnitDataFactory, + DocumentIndexingError, + DocumentService, + LLMBadRequestError, + MagicMock, + Mock, + ModelFeature, + ModelType, + NoPermissionError, + NotFound, + PipelineIconInfo, + ProviderTokenNotInitError, + RagPipelineDatasetCreateEntity, + SimpleNamespace, + TenantAccountRole, + _make_knowledge_configuration, + _make_retrieval_model, + _make_session_context, + json, + patch, + pytest, +) + + +class TestDatasetServiceQueries: + """Unit tests for DatasetService query composition and fallback branches.""" + + @pytest.fixture + def mock_dataset_query_dependencies(self): + with ( + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.helper.escape_like_pattern", return_value="escaped-search") as escape_like, + patch("services.dataset_service.TagService.get_target_ids_by_tag_ids") as get_target_ids, + ): + mock_db.paginate.return_value = SimpleNamespace(items=["dataset"], total=1) + yield { + "db": mock_db, + "escape_like_pattern": escape_like, + "get_target_ids": get_target_ids, + } + + def 
test_get_datasets_returns_paginated_results_for_public_view(self, mock_dataset_query_dependencies): + items, total = DatasetService.get_datasets(page=1, per_page=20, tenant_id="tenant-1") + + assert items == ["dataset"] + assert total == 1 + mock_dataset_query_dependencies["db"].paginate.assert_called_once() + mock_dataset_query_dependencies["escape_like_pattern"].assert_not_called() + + def test_get_datasets_short_circuits_for_dataset_operator_without_permissions( + self, mock_dataset_query_dependencies + ): + user = DatasetServiceUnitDataFactory.create_user_mock(role=TenantAccountRole.DATASET_OPERATOR) + mock_dataset_query_dependencies["db"].session.query.return_value.filter_by.return_value.all.return_value = [] + + items, total = DatasetService.get_datasets(page=1, per_page=20, tenant_id="tenant-1", user=user) + + assert items == [] + assert total == 0 + mock_dataset_query_dependencies["db"].paginate.assert_not_called() + + def test_get_datasets_short_circuits_when_tag_lookup_returns_no_target_ids(self, mock_dataset_query_dependencies): + mock_dataset_query_dependencies["get_target_ids"].return_value = [] + + items, total = DatasetService.get_datasets( + page=1, + per_page=20, + tenant_id="tenant-1", + tag_ids=["tag-1"], + ) + + assert items == [] + assert total == 0 + mock_dataset_query_dependencies["get_target_ids"].assert_called_once_with("knowledge", "tenant-1", ["tag-1"]) + mock_dataset_query_dependencies["db"].paginate.assert_not_called() + + def test_get_datasets_search_and_tag_filters_call_collaborators(self, mock_dataset_query_dependencies): + mock_dataset_query_dependencies["get_target_ids"].return_value = ["dataset-1"] + + items, total = DatasetService.get_datasets( + page=2, + per_page=10, + tenant_id="tenant-1", + search="report", + tag_ids=["tag-1"], + ) + + assert items == ["dataset"] + assert total == 1 + mock_dataset_query_dependencies["escape_like_pattern"].assert_called_once_with("report") + 
mock_dataset_query_dependencies["get_target_ids"].assert_called_once_with("knowledge", "tenant-1", ["tag-1"]) + mock_dataset_query_dependencies["db"].paginate.assert_called_once() + + def test_get_process_rules_returns_latest_rule_when_present(self): + dataset_process_rule = Mock(spec=DatasetProcessRule) + dataset_process_rule.mode = "automatic" + dataset_process_rule.rules_dict = {"delimiter": "\n"} + + with patch("services.dataset_service.db") as mock_db: + ( + mock_db.session.query.return_value.where.return_value.order_by.return_value.limit.return_value.one_or_none.return_value + ) = dataset_process_rule + + result = DatasetService.get_process_rules("dataset-1") + + assert result == {"mode": "automatic", "rules": {"delimiter": "\n"}} + + def test_get_process_rules_falls_back_to_default_rules_when_missing(self): + with patch("services.dataset_service.db") as mock_db: + ( + mock_db.session.query.return_value.where.return_value.order_by.return_value.limit.return_value.one_or_none.return_value + ) = None + + result = DatasetService.get_process_rules("dataset-1") + + assert result == { + "mode": DocumentService.DEFAULT_RULES["mode"], + "rules": DocumentService.DEFAULT_RULES["rules"], + } + + def test_get_datasets_by_ids_returns_empty_for_missing_ids(self): + with patch("services.dataset_service.db") as mock_db: + items, total = DatasetService.get_datasets_by_ids([], "tenant-1") + + assert items == [] + assert total == 0 + mock_db.paginate.assert_not_called() + + def test_get_datasets_by_ids_uses_paginate_for_non_empty_input(self): + with patch("services.dataset_service.db") as mock_db: + mock_db.paginate.return_value = SimpleNamespace(items=["dataset-1"], total=1) + + items, total = DatasetService.get_datasets_by_ids(["dataset-1"], "tenant-1") + + assert items == ["dataset-1"] + assert total == 1 + mock_db.paginate.assert_called_once() + + def test_get_dataset_returns_first_match(self): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock() + + with 
patch("services.dataset_service.db") as mock_db: + mock_db.session.query.return_value.filter_by.return_value.first.return_value = dataset + + result = DatasetService.get_dataset(dataset.id) + + assert result is dataset + + +class TestDatasetServiceValidation: + """Unit tests for DatasetService validation helpers.""" + + @pytest.mark.parametrize( + ("dataset_doc_form", "incoming_doc_form"), + [(None, "text_model"), ("text_model", "text_model")], + ) + def test_check_doc_form_allows_matching_or_missing_dataset_doc_form(self, dataset_doc_form, incoming_doc_form): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock(doc_form=dataset_doc_form) + + DatasetService.check_doc_form(dataset, incoming_doc_form) + + def test_check_doc_form_rejects_mismatched_doc_form(self): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock(doc_form="qa_model") + + with pytest.raises(ValueError, match="doc_form is different"): + DatasetService.check_doc_form(dataset, "text_model") + + def test_check_dataset_model_setting_skips_non_high_quality_datasets(self): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock(indexing_technique="economy") + + with patch("services.dataset_service.ModelManager") as model_manager_cls: + DatasetService.check_dataset_model_setting(dataset) + + model_manager_cls.assert_not_called() + + def test_check_dataset_model_setting_validates_embedding_model_for_high_quality_dataset(self): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock(indexing_technique="high_quality") + + with patch("services.dataset_service.ModelManager") as model_manager_cls: + DatasetService.check_dataset_model_setting(dataset) + + model_manager_cls.for_tenant.return_value.get_model_instance.assert_called_once_with( + tenant_id=dataset.tenant_id, + provider=dataset.embedding_model_provider, + model_type=ModelType.TEXT_EMBEDDING, + model=dataset.embedding_model, + ) + + def test_check_dataset_model_setting_wraps_llm_bad_request_error(self): + dataset = 
DatasetServiceUnitDataFactory.create_dataset_mock(indexing_technique="high_quality") + + with patch("services.dataset_service.ModelManager") as model_manager_cls: + model_manager_cls.for_tenant.return_value.get_model_instance.side_effect = LLMBadRequestError() + + with pytest.raises(ValueError, match="No Embedding Model available"): + DatasetService.check_dataset_model_setting(dataset) + + def test_check_dataset_model_setting_wraps_provider_token_error(self): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock(indexing_technique="high_quality") + + with patch("services.dataset_service.ModelManager") as model_manager_cls: + model_manager_cls.for_tenant.return_value.get_model_instance.side_effect = ProviderTokenNotInitError( + "token missing" + ) + + with pytest.raises(ValueError, match="The dataset is unavailable, due to: token missing"): + DatasetService.check_dataset_model_setting(dataset) + + def test_check_embedding_model_setting_wraps_provider_token_error_description(self): + with patch("services.dataset_service.ModelManager") as model_manager_cls: + model_manager_cls.for_tenant.return_value.get_model_instance.side_effect = ProviderTokenNotInitError( + "provider setup" + ) + + with pytest.raises(ValueError, match="provider setup"): + DatasetService.check_embedding_model_setting("tenant-1", "provider", "embedding-model") + + def test_check_reranking_model_setting_uses_rerank_model_type(self): + with patch("services.dataset_service.ModelManager") as model_manager_cls: + DatasetService.check_reranking_model_setting("tenant-1", "provider", "reranker") + + model_manager_cls.for_tenant.return_value.get_model_instance.assert_called_once_with( + tenant_id="tenant-1", + provider="provider", + model_type=ModelType.RERANK, + model="reranker", + ) + + def test_check_reranking_model_setting_wraps_bad_request(self): + with patch("services.dataset_service.ModelManager") as model_manager_cls: + model_manager_cls.for_tenant.return_value.get_model_instance.side_effect 
= LLMBadRequestError() + + with pytest.raises(ValueError, match="No Rerank Model available"): + DatasetService.check_reranking_model_setting("tenant-1", "provider", "reranker") + + def test_check_is_multimodal_model_returns_true_when_model_supports_vision(self): + model_schema = SimpleNamespace(features=[ModelFeature.VISION]) + model_type_instance = MagicMock() + model_type_instance.get_model_schema.return_value = model_schema + model_instance = SimpleNamespace( + model_type_instance=model_type_instance, + model_name="embedding-model", + credentials={"api_key": "secret"}, + ) + + with patch("services.dataset_service.ModelManager") as model_manager_cls: + model_manager_cls.for_tenant.return_value.get_model_instance.return_value = model_instance + + result = DatasetService.check_is_multimodal_model("tenant-1", "provider", "embedding-model") + + assert result is True + + def test_check_is_multimodal_model_returns_false_when_vision_feature_is_absent(self): + model_schema = SimpleNamespace(features=[]) + model_type_instance = MagicMock() + model_type_instance.get_model_schema.return_value = model_schema + model_instance = SimpleNamespace( + model_type_instance=model_type_instance, + model_name="embedding-model", + credentials={"api_key": "secret"}, + ) + + with patch("services.dataset_service.ModelManager") as model_manager_cls: + model_manager_cls.for_tenant.return_value.get_model_instance.return_value = model_instance + + result = DatasetService.check_is_multimodal_model("tenant-1", "provider", "embedding-model") + + assert result is False + + def test_check_is_multimodal_model_raises_when_schema_is_missing(self): + model_type_instance = MagicMock() + model_type_instance.get_model_schema.return_value = None + model_instance = SimpleNamespace( + model_type_instance=model_type_instance, + model_name="embedding-model", + credentials={"api_key": "secret"}, + ) + + with patch("services.dataset_service.ModelManager") as model_manager_cls: + 
model_manager_cls.for_tenant.return_value.get_model_instance.return_value = model_instance + + with pytest.raises(ValueError, match="Model schema not found"): + DatasetService.check_is_multimodal_model("tenant-1", "provider", "embedding-model") + + def test_check_is_multimodal_model_wraps_bad_request_error(self): + with patch("services.dataset_service.ModelManager") as model_manager_cls: + model_manager_cls.for_tenant.return_value.get_model_instance.side_effect = LLMBadRequestError() + + with pytest.raises(ValueError, match="No Model available"): + DatasetService.check_is_multimodal_model("tenant-1", "provider", "embedding-model") + + +class TestDatasetServiceCreationAndUpdate: + """Unit tests for dataset creation and update helpers.""" + + def test_create_empty_dataset_raises_when_name_already_exists(self): + account = SimpleNamespace(id="user-1") + + with patch("services.dataset_service.db") as mock_db: + mock_db.session.query.return_value.filter_by.return_value.first.return_value = object() + + with pytest.raises(DatasetNameDuplicateError, match="Dataset with name Dataset already exists"): + DatasetService.create_empty_dataset("tenant-1", "Dataset", None, "economy", account) + + def test_create_empty_dataset_uses_default_embedding_model_for_high_quality_dataset(self): + account = SimpleNamespace(id="user-1") + default_embedding_model = SimpleNamespace(provider="provider", model_name="default-embedding") + + with ( + patch("services.dataset_service.db") as mock_db, + patch( + "services.dataset_service.Dataset", + side_effect=lambda **kwargs: SimpleNamespace(id="dataset-1", **kwargs), + ), + patch("services.dataset_service.ModelManager") as model_manager_cls, + patch.object(DatasetService, "check_embedding_model_setting") as check_embedding, + ): + mock_db.session.query.return_value.filter_by.return_value.first.return_value = None + model_manager_cls.for_tenant.return_value.get_default_model_instance.return_value = default_embedding_model + + dataset = 
DatasetService.create_empty_dataset( + tenant_id="tenant-1", + name="Dataset", + description="Description", + indexing_technique="high_quality", + account=account, + ) + + assert dataset.embedding_model_provider == "provider" + assert dataset.embedding_model == "default-embedding" + assert dataset.permission == DatasetPermissionEnum.ONLY_ME + assert dataset.provider == "vendor" + model_manager_cls.for_tenant.return_value.get_default_model_instance.assert_called_once_with( + tenant_id="tenant-1", + model_type=ModelType.TEXT_EMBEDDING, + ) + check_embedding.assert_not_called() + mock_db.session.commit.assert_called_once() + + def test_create_empty_dataset_creates_external_binding_for_high_quality_dataset(self): + account = SimpleNamespace(id="user-1") + retrieval_model = _make_retrieval_model() + embedding_model = SimpleNamespace(provider="provider", model_name="embedding-model") + + with ( + patch("services.dataset_service.db") as mock_db, + patch( + "services.dataset_service.Dataset", + side_effect=lambda **kwargs: SimpleNamespace(id="dataset-1", **kwargs), + ), + patch( + "services.dataset_service.ExternalKnowledgeBindings", + side_effect=lambda **kwargs: SimpleNamespace(**kwargs), + ) as binding_cls, + patch("services.dataset_service.ModelManager") as model_manager_cls, + patch("services.dataset_service.ExternalDatasetService.get_external_knowledge_api", return_value=object()), + patch.object(DatasetService, "check_embedding_model_setting") as check_embedding, + patch.object(DatasetService, "check_reranking_model_setting") as check_reranking, + ): + mock_db.session.query.return_value.filter_by.return_value.first.return_value = None + model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model + + dataset = DatasetService.create_empty_dataset( + tenant_id="tenant-1", + name="External Dataset", + description="Description", + indexing_technique="high_quality", + account=account, + permission=DatasetPermissionEnum.ALL_TEAM, + 
provider="external", + external_knowledge_api_id="api-1", + external_knowledge_id="knowledge-1", + embedding_model_provider="provider", + embedding_model_name="embedding-model", + retrieval_model=retrieval_model, + summary_index_setting={"enable": True}, + ) + + assert dataset.embedding_model_provider == "provider" + assert dataset.embedding_model == "embedding-model" + assert dataset.retrieval_model == retrieval_model.model_dump() + assert dataset.summary_index_setting == {"enable": True} + check_embedding.assert_called_once_with("tenant-1", "provider", "embedding-model") + check_reranking.assert_called_once_with("tenant-1", "rerank-provider", "rerank-model") + binding_cls.assert_called_once_with( + tenant_id="tenant-1", + dataset_id="dataset-1", + external_knowledge_api_id="api-1", + external_knowledge_id="knowledge-1", + created_by="user-1", + ) + assert mock_db.session.add.call_count == 2 + mock_db.session.commit.assert_called_once() + + def test_create_empty_rag_pipeline_dataset_raises_for_duplicate_name(self): + entity = RagPipelineDatasetCreateEntity( + name="Existing Dataset", + description="Description", + icon_info=PipelineIconInfo(icon="book", icon_background="#fff"), + permission=DatasetPermissionEnum.ALL_TEAM, + ) + + with patch("services.dataset_service.db") as mock_db: + mock_db.session.query.return_value.filter_by.return_value.first.return_value = object() + + with pytest.raises(DatasetNameDuplicateError, match="Existing Dataset already exists"): + DatasetService.create_empty_rag_pipeline_dataset("tenant-1", entity) + + def test_create_empty_rag_pipeline_dataset_generates_name_and_creates_dataset(self): + entity = RagPipelineDatasetCreateEntity( + name="", + description="Description", + icon_info=PipelineIconInfo(icon="book", icon_background="#fff"), + permission=DatasetPermissionEnum.ALL_TEAM, + ) + pipeline = SimpleNamespace(id="pipeline-1") + + def pipeline_factory(**kwargs): + pipeline.__dict__.update(kwargs) + return pipeline + + def 
dataset_factory(**kwargs): + return SimpleNamespace(id="dataset-1", **kwargs) + + with ( + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.current_user", SimpleNamespace(id="user-1")), + patch("services.dataset_service.generate_incremental_name", return_value="Untitled 2") as generate_name, + patch("services.dataset_service.Pipeline", side_effect=pipeline_factory), + patch("services.dataset_service.Dataset", side_effect=dataset_factory), + ): + mock_db.session.query.return_value.filter_by.return_value.all.return_value = [ + SimpleNamespace(name="Untitled"), + SimpleNamespace(name="Untitled 1"), + ] + + dataset = DatasetService.create_empty_rag_pipeline_dataset("tenant-1", entity) + + assert entity.name == "Untitled 2" + assert dataset.pipeline_id == "pipeline-1" + assert dataset.runtime_mode == "rag_pipeline" + generate_name.assert_called_once_with(["Untitled", "Untitled 1"], "Untitled") + mock_db.session.commit.assert_called_once() + + def test_create_empty_rag_pipeline_dataset_requires_current_user_id(self): + entity = RagPipelineDatasetCreateEntity( + name="Dataset", + description="Description", + icon_info=PipelineIconInfo(icon="book", icon_background="#fff"), + permission=DatasetPermissionEnum.ALL_TEAM, + ) + + with ( + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.current_user", SimpleNamespace(id=None)), + ): + mock_db.session.query.return_value.filter_by.return_value.first.return_value = None + + with pytest.raises(ValueError, match="Current user or current user id not found"): + DatasetService.create_empty_rag_pipeline_dataset("tenant-1", entity) + + def test_update_dataset_raises_when_dataset_is_missing(self): + with patch.object(DatasetService, "get_dataset", return_value=None): + with pytest.raises(ValueError, match="Dataset not found"): + DatasetService.update_dataset("dataset-1", {}, SimpleNamespace(id="user-1")) + + def test_update_dataset_raises_when_new_name_conflicts(self): + 
dataset = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1", tenant_id="tenant-1") + dataset.name = "Old Dataset" + + with ( + patch.object(DatasetService, "get_dataset", return_value=dataset), + patch.object(DatasetService, "_has_dataset_same_name", return_value=True), + ): + with pytest.raises(ValueError, match="Dataset name already exists"): + DatasetService.update_dataset("dataset-1", {"name": "New Dataset"}, SimpleNamespace(id="user-1")) + + def test_update_dataset_routes_external_datasets_to_external_helper(self): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1", tenant_id="tenant-1") + dataset.provider = "external" + user = DatasetServiceUnitDataFactory.create_user_mock() + + with ( + patch.object(DatasetService, "get_dataset", return_value=dataset), + patch.object(DatasetService, "check_dataset_permission") as check_permission, + patch.object(DatasetService, "_update_external_dataset", return_value="updated") as update_external, + ): + result = DatasetService.update_dataset("dataset-1", {"name": dataset.name}, user) + + assert result == "updated" + check_permission.assert_called_once_with(dataset, user) + update_external.assert_called_once_with(dataset, {"name": dataset.name}, user) + + def test_update_dataset_routes_internal_datasets_to_internal_helper(self): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1", tenant_id="tenant-1") + dataset.provider = "vendor" + user = DatasetServiceUnitDataFactory.create_user_mock() + + with ( + patch.object(DatasetService, "get_dataset", return_value=dataset), + patch.object(DatasetService, "check_dataset_permission") as check_permission, + patch.object(DatasetService, "_update_internal_dataset", return_value="updated") as update_internal, + ): + result = DatasetService.update_dataset("dataset-1", {"name": dataset.name}, user) + + assert result == "updated" + check_permission.assert_called_once_with(dataset, user) + 
update_internal.assert_called_once_with(dataset, {"name": dataset.name}, user) + + def test_has_dataset_same_name_returns_true_when_query_matches(self): + with patch("services.dataset_service.db") as mock_db: + mock_db.session.query.return_value.where.return_value.first.return_value = object() + + result = DatasetService._has_dataset_same_name("tenant-1", "dataset-1", "Dataset") + + assert result is True + + def test_update_external_dataset_updates_dataset_and_binding(self): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1") + user = SimpleNamespace(id="user-1") + now = object() + + with ( + patch.object(DatasetService, "_update_external_knowledge_binding") as update_binding, + patch("services.dataset_service.naive_utc_now", return_value=now), + patch("services.dataset_service.db") as mock_db, + ): + result = DatasetService._update_external_dataset( + dataset, + { + "external_retrieval_model": {"top_k": 3}, + "summary_index_setting": {"enable": True}, + "name": "Updated Dataset", + "description": "Updated description", + "permission": DatasetPermissionEnum.PARTIAL_TEAM, + "external_knowledge_id": "knowledge-1", + "external_knowledge_api_id": "api-1", + }, + user, + ) + + assert result is dataset + assert dataset.retrieval_model == {"top_k": 3} + assert dataset.summary_index_setting == {"enable": True} + assert dataset.name == "Updated Dataset" + assert dataset.description == "Updated description" + assert dataset.permission == DatasetPermissionEnum.PARTIAL_TEAM + assert dataset.updated_by == "user-1" + assert dataset.updated_at is now + update_binding.assert_called_once_with("dataset-1", "knowledge-1", "api-1") + mock_db.session.add.assert_called_once_with(dataset) + mock_db.session.commit.assert_called_once() + + @pytest.mark.parametrize( + ("payload", "message"), + [ + ({"external_knowledge_api_id": "api-1"}, "External knowledge id is required"), + ({"external_knowledge_id": "knowledge-1"}, "External knowledge api id is 
required"), + ], + ) + def test_update_external_dataset_requires_external_binding_fields(self, payload, message): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1") + + with pytest.raises(ValueError, match=message): + DatasetService._update_external_dataset(dataset, payload, SimpleNamespace(id="user-1")) + + def test_update_external_knowledge_binding_updates_changed_binding_values(self): + binding = SimpleNamespace(external_knowledge_id="old-knowledge", external_knowledge_api_id="old-api") + session = MagicMock() + session.query.return_value.filter_by.return_value.first.return_value = binding + session_context = _make_session_context(session) + + with ( + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.Session", return_value=session_context), + ): + DatasetService._update_external_knowledge_binding("dataset-1", "new-knowledge", "new-api") + + assert binding.external_knowledge_id == "new-knowledge" + assert binding.external_knowledge_api_id == "new-api" + mock_db.session.add.assert_called_once_with(binding) + + def test_update_external_knowledge_binding_raises_for_missing_binding(self): + session = MagicMock() + session.query.return_value.filter_by.return_value.first.return_value = None + session_context = _make_session_context(session) + + with ( + patch("services.dataset_service.db"), + patch("services.dataset_service.Session", return_value=session_context), + ): + with pytest.raises(ValueError, match="External knowledge binding not found"): + DatasetService._update_external_knowledge_binding("dataset-1", "knowledge-1", "api-1") + + def test_update_internal_dataset_updates_fields_and_dispatches_regeneration_tasks(self): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1") + user = SimpleNamespace(id="user-1") + now = object() + update_payload = { + "name": "Updated Dataset", + "description": None, + "partial_member_list": [{"user_id": "member-1"}], + 
"external_knowledge_api_id": "api-1", + "external_knowledge_id": "knowledge-1", + "external_retrieval_model": {"top_k": 2}, + "retrieval_model": {"top_k": 4}, + "summary_index_setting": {"enable": True}, + "icon_info": {"icon": "book"}, + } + + with ( + patch.object(DatasetService, "_handle_indexing_technique_change", return_value="update"), + patch.object(DatasetService, "_update_pipeline_knowledge_base_node_data") as update_pipeline, + patch("services.dataset_service.naive_utc_now", return_value=now), + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.deal_dataset_vector_index_task") as vector_task, + patch("services.dataset_service.regenerate_summary_index_task") as regenerate_task, + ): + result = DatasetService._update_internal_dataset(dataset, update_payload.copy(), user) + + assert result is dataset + updated_values = mock_db.session.query.return_value.filter_by.return_value.update.call_args.args[0] + assert updated_values["name"] == "Updated Dataset" + assert updated_values["description"] is None + assert updated_values["retrieval_model"] == {"top_k": 4} + assert updated_values["summary_index_setting"] == {"enable": True} + assert updated_values["icon_info"] == {"icon": "book"} + assert updated_values["updated_by"] == "user-1" + assert updated_values["updated_at"] is now + assert "partial_member_list" not in updated_values + assert "external_knowledge_api_id" not in updated_values + assert "external_knowledge_id" not in updated_values + assert "external_retrieval_model" not in updated_values + mock_db.session.commit.assert_called_once() + mock_db.session.refresh.assert_called_once_with(dataset) + update_pipeline.assert_called_once_with(dataset, "user-1") + vector_task.delay.assert_called_once_with("dataset-1", "update") + regenerate_task.delay.assert_called_once_with( + "dataset-1", + regenerate_reason="embedding_model_changed", + regenerate_vectors_only=True, + ) + + def 
test_update_pipeline_knowledge_base_node_data_returns_early_for_non_pipeline_dataset(self): + dataset = SimpleNamespace(runtime_mode="workflow", pipeline_id="pipeline-1") + + with patch("services.dataset_service.db") as mock_db: + DatasetService._update_pipeline_knowledge_base_node_data(dataset, "user-1") + + mock_db.session.query.assert_not_called() + + def test_update_pipeline_knowledge_base_node_data_returns_when_pipeline_is_missing(self): + dataset = SimpleNamespace(runtime_mode="rag_pipeline", pipeline_id="pipeline-1") + + with patch("services.dataset_service.db") as mock_db: + mock_db.session.query.return_value.filter_by.return_value.first.return_value = None + + DatasetService._update_pipeline_knowledge_base_node_data(dataset, "user-1") + + mock_db.session.commit.assert_not_called() + + def test_update_pipeline_knowledge_base_node_data_updates_published_and_draft_workflows(self): + dataset = SimpleNamespace( + id="dataset-1", + runtime_mode="rag_pipeline", + pipeline_id="pipeline-1", + embedding_model="embedding-model", + embedding_model_provider="provider", + retrieval_model={"top_k": 5}, + chunk_structure="paragraph", + indexing_technique="high_quality", + keyword_number=8, + summary_index_setting={"enable": True}, + ) + pipeline = SimpleNamespace(id="pipeline-1", tenant_id="tenant-1") + published_workflow = SimpleNamespace( + graph=json.dumps({"nodes": [{"data": {"type": "knowledge-index"}}, {"data": {"type": "start"}}]}), + type="chat", + features={"feature": True}, + environment_variables=[], + conversation_variables=[], + rag_pipeline_variables=[], + ) + draft_workflow = SimpleNamespace(graph=json.dumps({"nodes": [{"data": {"type": "knowledge-index"}}]})) + new_workflow = SimpleNamespace(id="workflow-1") + rag_pipeline_service = MagicMock() + rag_pipeline_service.get_published_workflow.return_value = published_workflow + rag_pipeline_service.get_draft_workflow.return_value = draft_workflow + + with ( + patch("services.dataset_service.db") as mock_db, + 
patch("services.dataset_service.RagPipelineService", return_value=rag_pipeline_service), + patch("services.dataset_service.Workflow.new", return_value=new_workflow) as workflow_new, + ): + mock_db.session.query.return_value.filter_by.return_value.first.return_value = pipeline + + DatasetService._update_pipeline_knowledge_base_node_data(dataset, "user-1") + + published_graph = json.loads(workflow_new.call_args.kwargs["graph"]) + assert published_graph["nodes"][0]["data"]["embedding_model"] == "embedding-model" + assert published_graph["nodes"][0]["data"]["summary_index_setting"] == {"enable": True} + assert json.loads(draft_workflow.graph)["nodes"][0]["data"]["embedding_model_provider"] == "provider" + mock_db.session.add.assert_any_call(new_workflow) + mock_db.session.add.assert_any_call(draft_workflow) + mock_db.session.commit.assert_called_once() + + def test_update_pipeline_knowledge_base_node_data_rolls_back_when_update_fails(self): + dataset = SimpleNamespace(runtime_mode="rag_pipeline", pipeline_id="pipeline-1") + pipeline = SimpleNamespace(id="pipeline-1", tenant_id="tenant-1") + rag_pipeline_service = MagicMock() + rag_pipeline_service.get_published_workflow.side_effect = RuntimeError("boom") + + with ( + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.RagPipelineService", return_value=rag_pipeline_service), + ): + mock_db.session.query.return_value.filter_by.return_value.first.return_value = pipeline + + with pytest.raises(RuntimeError, match="boom"): + DatasetService._update_pipeline_knowledge_base_node_data(dataset, "user-1") + + mock_db.session.rollback.assert_called_once() + + def test_handle_indexing_technique_change_returns_none_without_indexing_technique(self): + filtered_data: dict[str, object] = {} + dataset = SimpleNamespace(indexing_technique="economy") + + result = DatasetService._handle_indexing_technique_change(dataset, {}, filtered_data) + + assert result is None + assert filtered_data == {} + + def 
test_handle_indexing_technique_change_switches_to_economy(self): + filtered_data: dict[str, object] = {} + dataset = SimpleNamespace(indexing_technique="high_quality") + + result = DatasetService._handle_indexing_technique_change( + dataset, + {"indexing_technique": "economy"}, + filtered_data, + ) + + assert result == "remove" + assert filtered_data == { + "embedding_model": None, + "embedding_model_provider": None, + "collection_binding_id": None, + } + + def test_handle_indexing_technique_change_switches_to_high_quality(self): + filtered_data: dict[str, object] = {} + dataset = SimpleNamespace(indexing_technique="economy") + + with patch.object(DatasetService, "_configure_embedding_model_for_high_quality") as configure_embedding: + result = DatasetService._handle_indexing_technique_change( + dataset, + {"indexing_technique": "high_quality"}, + filtered_data, + ) + + assert result == "add" + configure_embedding.assert_called_once_with({"indexing_technique": "high_quality"}, filtered_data) + + def test_handle_indexing_technique_change_delegates_when_technique_is_unchanged(self): + filtered_data: dict[str, object] = {} + dataset = SimpleNamespace(indexing_technique="high_quality") + + with patch.object( + DatasetService, + "_handle_embedding_model_update_when_technique_unchanged", + return_value="update", + ) as update_embedding: + result = DatasetService._handle_indexing_technique_change( + dataset, + {"indexing_technique": "high_quality"}, + filtered_data, + ) + + assert result == "update" + update_embedding.assert_called_once_with(dataset, {"indexing_technique": "high_quality"}, filtered_data) + + def test_configure_embedding_model_for_high_quality_updates_filtered_data(self): + class FakeAccount: + pass + + current_user = FakeAccount() + current_user.current_tenant_id = "tenant-1" + embedding_model = SimpleNamespace(provider="provider", model_name="embedding-model") + filtered_data: dict[str, object] = {} + + with ( + patch("services.dataset_service.Account", 
FakeAccount), + patch("services.dataset_service.current_user", current_user), + patch("services.dataset_service.ModelManager") as model_manager_cls, + patch( + "services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding", + return_value=SimpleNamespace(id="binding-1"), + ), + ): + model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model + + DatasetService._configure_embedding_model_for_high_quality( + {"embedding_model_provider": "provider", "embedding_model": "embedding-model"}, + filtered_data, + ) + + assert filtered_data == { + "embedding_model": "embedding-model", + "embedding_model_provider": "provider", + "collection_binding_id": "binding-1", + } + + @pytest.mark.parametrize( + ("error", "message"), + [ + (LLMBadRequestError(), "No Embedding Model available"), + (ProviderTokenNotInitError("provider setup"), "provider setup"), + ], + ) + def test_configure_embedding_model_for_high_quality_wraps_model_errors(self, error, message): + class FakeAccount: + pass + + current_user = FakeAccount() + current_user.current_tenant_id = "tenant-1" + + with ( + patch("services.dataset_service.Account", FakeAccount), + patch("services.dataset_service.current_user", current_user), + patch("services.dataset_service.ModelManager") as model_manager_cls, + ): + model_manager_cls.for_tenant.return_value.get_model_instance.side_effect = error + + with pytest.raises(ValueError, match=message): + DatasetService._configure_embedding_model_for_high_quality( + {"embedding_model_provider": "provider", "embedding_model": "embedding-model"}, + {}, + ) + + def test_handle_embedding_model_update_when_technique_unchanged_preserves_existing_settings(self): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock( + embedding_model_provider="provider", + embedding_model="embedding-model", + ) + filtered_data: dict[str, object] = {} + + with patch.object(DatasetService, "_preserve_existing_embedding_settings") as 
preserve_settings: + result = DatasetService._handle_embedding_model_update_when_technique_unchanged( + dataset, + {}, + filtered_data, + ) + + assert result is None + preserve_settings.assert_called_once_with(dataset, filtered_data) + + def test_handle_embedding_model_update_when_technique_unchanged_updates_when_model_is_provided(self): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock( + embedding_model_provider="provider", + embedding_model="embedding-model", + ) + + with patch.object(DatasetService, "_update_embedding_model_settings", return_value="update") as update_settings: + result = DatasetService._handle_embedding_model_update_when_technique_unchanged( + dataset, + {"embedding_model_provider": "provider-two", "embedding_model": "embedding-model-two"}, + {}, + ) + + assert result == "update" + update_settings.assert_called_once() + + def test_preserve_existing_embedding_settings_keeps_current_binding(self): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock( + embedding_model_provider="provider", + embedding_model="embedding-model", + collection_binding_id="binding-1", + ) + filtered_data = {"embedding_model_provider": "", "embedding_model": ""} + + DatasetService._preserve_existing_embedding_settings(dataset, filtered_data) + + assert filtered_data == { + "embedding_model_provider": "provider", + "embedding_model": "embedding-model", + "collection_binding_id": "binding-1", + } + + def test_preserve_existing_embedding_settings_removes_empty_placeholders_without_existing_values(self): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock( + embedding_model_provider=None, + embedding_model=None, + collection_binding_id=None, + ) + filtered_data = {"embedding_model_provider": "", "embedding_model": ""} + + DatasetService._preserve_existing_embedding_settings(dataset, filtered_data) + + assert filtered_data == {} + + def test_update_embedding_model_settings_returns_update_for_changed_values(self): + dataset = 
DatasetServiceUnitDataFactory.create_dataset_mock( + embedding_model_provider="provider", + embedding_model="embedding-model", + ) + + with patch.object(DatasetService, "_apply_new_embedding_settings") as apply_settings: + result = DatasetService._update_embedding_model_settings( + dataset, + {"embedding_model_provider": "provider-two", "embedding_model": "embedding-model-two"}, + {}, + ) + + assert result == "update" + apply_settings.assert_called_once() + + def test_update_embedding_model_settings_returns_none_for_unchanged_values(self): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock( + embedding_model_provider="provider", + embedding_model="embedding-model", + ) + + result = DatasetService._update_embedding_model_settings( + dataset, + {"embedding_model_provider": "provider", "embedding_model": "embedding-model"}, + {}, + ) + + assert result is None + + def test_update_embedding_model_settings_wraps_bad_request_errors(self): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock( + embedding_model_provider="provider", + embedding_model="embedding-model", + ) + + with patch.object(DatasetService, "_apply_new_embedding_settings", side_effect=LLMBadRequestError()): + with pytest.raises(ValueError, match="No Embedding Model available"): + DatasetService._update_embedding_model_settings( + dataset, + {"embedding_model_provider": "provider-two", "embedding_model": "embedding-model-two"}, + {}, + ) + + def test_apply_new_embedding_settings_updates_binding_for_new_model(self): + class FakeAccount: + pass + + current_user = FakeAccount() + current_user.current_tenant_id = "tenant-1" + dataset = DatasetServiceUnitDataFactory.create_dataset_mock(collection_binding_id="binding-1") + filtered_data: dict[str, object] = {} + + with ( + patch("services.dataset_service.Account", FakeAccount), + patch("services.dataset_service.current_user", current_user), + patch("services.dataset_service.ModelManager") as model_manager_cls, + patch( + 
"services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding", + return_value=SimpleNamespace(id="binding-2"), + ), + ): + model_manager_cls.for_tenant.return_value.get_model_instance.return_value = SimpleNamespace( + provider="provider-two", + model_name="embedding-model-two", + ) + + DatasetService._apply_new_embedding_settings( + dataset, + {"embedding_model_provider": "provider-two", "embedding_model": "embedding-model-two"}, + filtered_data, + ) + + assert filtered_data == { + "embedding_model": "embedding-model-two", + "embedding_model_provider": "provider-two", + "collection_binding_id": "binding-2", + } + + def test_apply_new_embedding_settings_preserves_existing_values_when_provider_token_is_missing(self): + class FakeAccount: + pass + + current_user = FakeAccount() + current_user.current_tenant_id = "tenant-1" + dataset = DatasetServiceUnitDataFactory.create_dataset_mock( + embedding_model_provider="provider", + embedding_model="embedding-model", + collection_binding_id="binding-1", + ) + filtered_data: dict[str, object] = {} + + with ( + patch("services.dataset_service.Account", FakeAccount), + patch("services.dataset_service.current_user", current_user), + patch("services.dataset_service.ModelManager") as model_manager_cls, + ): + model_manager_cls.for_tenant.return_value.get_model_instance.side_effect = ProviderTokenNotInitError( + "token missing" + ) + + DatasetService._apply_new_embedding_settings( + dataset, + {"embedding_model_provider": "provider-two", "embedding_model": "embedding-model-two"}, + filtered_data, + ) + + assert filtered_data == { + "embedding_model_provider": "provider", + "embedding_model": "embedding-model", + "collection_binding_id": "binding-1", + } + + @pytest.mark.parametrize( + ("summary_index_setting", "expected"), + [ + (None, False), + ({"enable": False}, False), + ({"enable": True, "model_name": "old-model", "model_provider_name": "provider"}, False), + ({"enable": True, "model_name": 
"new-model", "model_provider_name": "provider-two"}, True), + ], + ) + def test_check_summary_index_setting_model_changed(self, summary_index_setting, expected): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock( + dataset_id="dataset-1", + summary_index_setting={"enable": True, "model_name": "old-model", "model_provider_name": "provider"}, + ) + + result = DatasetService._check_summary_index_setting_model_changed( + dataset, + {"summary_index_setting": summary_index_setting} if summary_index_setting is not None else {}, + ) + + assert result is expected + + +class TestDatasetServiceRagPipelineSettings: + """Unit tests for rag-pipeline dataset setting updates.""" + + def test_update_rag_pipeline_dataset_settings_requires_current_tenant(self): + session = MagicMock() + dataset = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1") + knowledge_configuration = _make_knowledge_configuration() + + with patch("services.dataset_service.current_user", SimpleNamespace(current_tenant_id=None)): + with pytest.raises(ValueError, match="Current user or current tenant not found"): + DatasetService.update_rag_pipeline_dataset_settings(session, dataset, knowledge_configuration) + + def test_update_rag_pipeline_dataset_settings_without_published_high_quality_updates_embedding_settings(self): + session = MagicMock() + dataset = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1") + session.merge.return_value = dataset + knowledge_configuration = _make_knowledge_configuration(summary_index_setting={"enable": True}) + embedding_model = SimpleNamespace(provider="provider", model_name="embedding-model") + + with ( + patch("services.dataset_service.current_user", SimpleNamespace(current_tenant_id="tenant-1")), + patch("services.dataset_service.ModelManager") as model_manager_cls, + patch.object(DatasetService, "check_is_multimodal_model", return_value=True) as check_multimodal, + patch( + 
"services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding", + return_value=SimpleNamespace(id="binding-1"), + ), + ): + model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model + + DatasetService.update_rag_pipeline_dataset_settings(session, dataset, knowledge_configuration) + + assert dataset.chunk_structure == "paragraph" + assert dataset.indexing_technique == "high_quality" + assert dataset.embedding_model == "embedding-model" + assert dataset.embedding_model_provider == "provider" + assert dataset.collection_binding_id == "binding-1" + assert dataset.is_multimodal is True + assert dataset.retrieval_model == knowledge_configuration.retrieval_model.model_dump() + assert dataset.summary_index_setting == {"enable": True} + check_multimodal.assert_called_once_with("tenant-1", "provider", "embedding-model") + session.add.assert_called_once_with(dataset) + session.commit.assert_not_called() + + def test_update_rag_pipeline_dataset_settings_without_published_economy_updates_keyword_number(self): + session = MagicMock() + dataset = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1") + session.merge.return_value = dataset + knowledge_configuration = _make_knowledge_configuration( + indexing_technique="economy", + embedding_model_provider="", + embedding_model="", + keyword_number=12, + ) + + with patch("services.dataset_service.current_user", SimpleNamespace(current_tenant_id="tenant-1")): + DatasetService.update_rag_pipeline_dataset_settings(session, dataset, knowledge_configuration) + + assert dataset.indexing_technique == "economy" + assert dataset.keyword_number == 12 + assert dataset.retrieval_model == knowledge_configuration.retrieval_model.model_dump() + session.add.assert_called_once_with(dataset) + + def test_update_rag_pipeline_dataset_settings_with_published_rejects_chunk_structure_changes(self): + session = MagicMock() + dataset = 
DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1") + dataset.chunk_structure = "paragraph" + session.merge.return_value = dataset + knowledge_configuration = _make_knowledge_configuration(chunk_structure="sentence") + + with patch("services.dataset_service.current_user", SimpleNamespace(current_tenant_id="tenant-1")): + with pytest.raises(ValueError, match="Chunk structure is not allowed to be updated"): + DatasetService.update_rag_pipeline_dataset_settings( + session, + dataset, + knowledge_configuration, + has_published=True, + ) + + def test_update_rag_pipeline_dataset_settings_with_published_rejects_switch_to_economy(self): + session = MagicMock() + dataset = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1") + dataset.chunk_structure = "paragraph" + dataset.indexing_technique = "high_quality" + session.merge.return_value = dataset + knowledge_configuration = _make_knowledge_configuration( + indexing_technique="economy", + embedding_model_provider="", + embedding_model="", + ) + + with patch("services.dataset_service.current_user", SimpleNamespace(current_tenant_id="tenant-1")): + with pytest.raises( + ValueError, + match="Knowledge base indexing technique is not allowed to be updated to economy", + ): + DatasetService.update_rag_pipeline_dataset_settings( + session, + dataset, + knowledge_configuration, + has_published=True, + ) + + def test_update_rag_pipeline_dataset_settings_with_published_adds_high_quality_index(self): + session = MagicMock() + dataset = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1") + dataset.chunk_structure = "paragraph" + dataset.indexing_technique = "economy" + session.merge.return_value = dataset + knowledge_configuration = _make_knowledge_configuration() + embedding_model = SimpleNamespace(provider="provider", model_name="embedding-model") + + with ( + patch("services.dataset_service.current_user", SimpleNamespace(current_tenant_id="tenant-1")), + 
patch("services.dataset_service.ModelManager") as model_manager_cls, + patch.object(DatasetService, "check_is_multimodal_model", return_value=False), + patch( + "services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding", + return_value=SimpleNamespace(id="binding-1"), + ), + patch("services.dataset_service.deal_dataset_index_update_task") as update_task, + ): + model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model + + DatasetService.update_rag_pipeline_dataset_settings( + session, + dataset, + knowledge_configuration, + has_published=True, + ) + + assert dataset.indexing_technique == "high_quality" + assert dataset.embedding_model == "embedding-model" + assert dataset.embedding_model_provider == "provider" + assert dataset.collection_binding_id == "binding-1" + assert dataset.is_multimodal is False + assert dataset.retrieval_model == knowledge_configuration.retrieval_model.model_dump() + session.add.assert_called_once_with(dataset) + session.commit.assert_called_once() + update_task.delay.assert_called_once_with("dataset-1", "add") + + def test_update_rag_pipeline_dataset_settings_with_published_updates_changed_embedding_model(self): + session = MagicMock() + dataset = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1") + dataset.chunk_structure = "paragraph" + dataset.indexing_technique = "high_quality" + dataset.embedding_model_provider = "provider" + dataset.embedding_model = "embedding-model" + session.merge.return_value = dataset + knowledge_configuration = _make_knowledge_configuration( + embedding_model_provider="provider-two", + embedding_model="embedding-model-two", + summary_index_setting={"enable": True}, + ) + + with ( + patch("services.dataset_service.current_user", SimpleNamespace(current_tenant_id="tenant-1")), + patch("services.dataset_service.ModelManager") as model_manager_cls, + patch.object(DatasetService, "check_is_multimodal_model", return_value=True), + 
patch( + "services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding", + return_value=SimpleNamespace(id="binding-2"), + ), + patch("services.dataset_service.deal_dataset_index_update_task") as update_task, + ): + model_manager_cls.for_tenant.return_value.get_model_instance.return_value = SimpleNamespace( + provider="provider-two", + model_name="embedding-model-two", + ) + + DatasetService.update_rag_pipeline_dataset_settings( + session, + dataset, + knowledge_configuration, + has_published=True, + ) + + assert dataset.embedding_model_provider == "provider-two" + assert dataset.embedding_model == "embedding-model-two" + assert dataset.collection_binding_id == "binding-2" + assert dataset.is_multimodal is True + assert dataset.summary_index_setting == {"enable": True} + session.add.assert_called_once_with(dataset) + session.commit.assert_called_once() + update_task.delay.assert_called_once_with("dataset-1", "update") + + def test_update_rag_pipeline_dataset_settings_with_published_skips_embedding_update_when_token_is_missing(self): + session = MagicMock() + dataset = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1") + dataset.chunk_structure = "paragraph" + dataset.indexing_technique = "high_quality" + dataset.embedding_model_provider = "provider" + dataset.embedding_model = "embedding-model" + session.merge.return_value = dataset + knowledge_configuration = _make_knowledge_configuration( + embedding_model_provider="provider-two", + embedding_model="embedding-model-two", + ) + + with ( + patch("services.dataset_service.current_user", SimpleNamespace(current_tenant_id="tenant-1")), + patch("services.dataset_service.ModelManager") as model_manager_cls, + patch("services.dataset_service.deal_dataset_index_update_task") as update_task, + ): + model_manager_cls.for_tenant.return_value.get_model_instance.side_effect = ProviderTokenNotInitError( + "token missing" + ) + + 
DatasetService.update_rag_pipeline_dataset_settings( + session, + dataset, + knowledge_configuration, + has_published=True, + ) + + assert dataset.embedding_model_provider == "provider" + assert dataset.embedding_model == "embedding-model" + assert dataset.retrieval_model == knowledge_configuration.retrieval_model.model_dump() + session.add.assert_called_once_with(dataset) + session.commit.assert_called_once() + update_task.delay.assert_called_once_with("dataset-1", "update") + + def test_update_rag_pipeline_dataset_settings_with_published_updates_economy_keyword_number(self): + session = MagicMock() + dataset = DatasetServiceUnitDataFactory.create_dataset_mock(dataset_id="dataset-1") + dataset.chunk_structure = "paragraph" + dataset.indexing_technique = "economy" + dataset.keyword_number = 5 + session.merge.return_value = dataset + knowledge_configuration = _make_knowledge_configuration( + indexing_technique="economy", + embedding_model_provider="", + embedding_model="", + keyword_number=9, + ) + + with ( + patch("services.dataset_service.current_user", SimpleNamespace(current_tenant_id="tenant-1")), + patch("services.dataset_service.deal_dataset_index_update_task") as update_task, + ): + DatasetService.update_rag_pipeline_dataset_settings( + session, + dataset, + knowledge_configuration, + has_published=True, + ) + + assert dataset.keyword_number == 9 + assert dataset.retrieval_model == knowledge_configuration.retrieval_model.model_dump() + session.add.assert_called_once_with(dataset) + session.commit.assert_called_once() + update_task.delay.assert_not_called() + + +class TestDatasetServicePermissionsAndLifecycle: + """Unit tests for dataset permissions, deletion, and metadata helpers.""" + + def test_delete_dataset_returns_false_when_dataset_is_missing(self): + with patch.object(DatasetService, "get_dataset", return_value=None): + result = DatasetService.delete_dataset("dataset-1", user=SimpleNamespace(id="user-1")) + + assert result is False + + def 
test_delete_dataset_checks_permission_and_deletes_dataset(self): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock() + + with ( + patch.object(DatasetService, "get_dataset", return_value=dataset), + patch.object(DatasetService, "check_dataset_permission") as check_permission, + patch("services.dataset_service.dataset_was_deleted.send") as send_deleted_signal, + patch("services.dataset_service.db") as mock_db, + ): + result = DatasetService.delete_dataset(dataset.id, user=SimpleNamespace(id="user-1")) + + assert result is True + check_permission.assert_called_once_with(dataset, SimpleNamespace(id="user-1")) + send_deleted_signal.assert_called_once_with(dataset) + mock_db.session.delete.assert_called_once_with(dataset) + mock_db.session.commit.assert_called_once() + + def test_dataset_use_check_returns_scalar_result(self): + with patch("services.dataset_service.db") as mock_db: + mock_db.session.execute.return_value.scalar_one.return_value = True + + result = DatasetService.dataset_use_check("dataset-1") + + assert result is True + + def test_check_dataset_permission_rejects_cross_tenant_access(self): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock(tenant_id="tenant-a") + user = DatasetServiceUnitDataFactory.create_user_mock(tenant_id="tenant-b") + + with pytest.raises(NoPermissionError, match="do not have permission"): + DatasetService.check_dataset_permission(dataset, user) + + def test_check_dataset_permission_rejects_only_me_dataset_for_non_creator(self): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock( + permission=DatasetPermissionEnum.ONLY_ME, + created_by="owner-1", + ) + user = DatasetServiceUnitDataFactory.create_user_mock( + user_id="member-1", + role=TenantAccountRole.EDITOR, + ) + + with pytest.raises(NoPermissionError, match="do not have permission"): + DatasetService.check_dataset_permission(dataset, user) + + def test_check_dataset_permission_rejects_partial_team_user_without_binding(self): + dataset = 
DatasetServiceUnitDataFactory.create_dataset_mock( + permission=DatasetPermissionEnum.PARTIAL_TEAM, + created_by="owner-1", + ) + user = DatasetServiceUnitDataFactory.create_user_mock( + user_id="member-1", + role=TenantAccountRole.EDITOR, + ) + + with patch("services.dataset_service.db") as mock_db: + mock_db.session.query.return_value.filter_by.return_value.first.return_value = None + + with pytest.raises(NoPermissionError, match="do not have permission"): + DatasetService.check_dataset_permission(dataset, user) + + def test_check_dataset_permission_allows_partial_team_creator_without_lookup(self): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock( + permission=DatasetPermissionEnum.PARTIAL_TEAM, + created_by="creator-1", + ) + user = DatasetServiceUnitDataFactory.create_user_mock( + user_id="creator-1", + role=TenantAccountRole.EDITOR, + ) + + with patch("services.dataset_service.db") as mock_db: + DatasetService.check_dataset_permission(dataset, user) + + mock_db.session.query.assert_not_called() + + def test_check_dataset_permission_allows_partial_team_member_with_binding(self): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock( + permission=DatasetPermissionEnum.PARTIAL_TEAM, + created_by="owner-1", + ) + user = DatasetServiceUnitDataFactory.create_user_mock( + user_id="member-1", + role=TenantAccountRole.EDITOR, + ) + + with patch("services.dataset_service.db") as mock_db: + mock_db.session.query.return_value.filter_by.return_value.first.return_value = object() + + DatasetService.check_dataset_permission(dataset, user) + + def test_check_dataset_operator_permission_validates_required_arguments(self): + with pytest.raises(ValueError, match="Dataset not found"): + DatasetService.check_dataset_operator_permission(user=SimpleNamespace(id="user-1"), dataset=None) + + with pytest.raises(ValueError, match="User not found"): + DatasetService.check_dataset_operator_permission(user=None, dataset=SimpleNamespace(id="dataset-1")) + + def 
test_check_dataset_operator_permission_rejects_only_me_for_non_creator(self): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock( + permission=DatasetPermissionEnum.ONLY_ME, + created_by="owner-1", + ) + user = DatasetServiceUnitDataFactory.create_user_mock( + user_id="member-1", + role=TenantAccountRole.EDITOR, + ) + + with pytest.raises(NoPermissionError, match="do not have permission"): + DatasetService.check_dataset_operator_permission(user=user, dataset=dataset) + + def test_check_dataset_operator_permission_rejects_partial_team_without_binding(self): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock(permission=DatasetPermissionEnum.PARTIAL_TEAM) + user = DatasetServiceUnitDataFactory.create_user_mock( + user_id="member-1", + role=TenantAccountRole.EDITOR, + ) + + with patch("services.dataset_service.db") as mock_db: + mock_db.session.query.return_value.filter_by.return_value.all.return_value = [] + + with pytest.raises(NoPermissionError, match="do not have permission"): + DatasetService.check_dataset_operator_permission(user=user, dataset=dataset) + + def test_get_dataset_queries_delegates_to_paginate(self): + with patch("services.dataset_service.db") as mock_db: + mock_db.desc.side_effect = lambda column: column + mock_db.paginate.return_value = SimpleNamespace(items=["query"], total=1) + + items, total = DatasetService.get_dataset_queries("dataset-1", page=1, per_page=20) + + assert items == ["query"] + assert total == 1 + mock_db.paginate.assert_called_once() + + def test_get_related_apps_returns_ordered_query_results(self): + with patch("services.dataset_service.db") as mock_db: + mock_db.desc.side_effect = lambda column: column + mock_db.session.query.return_value.where.return_value.order_by.return_value.all.return_value = [ + "relation-1" + ] + + result = DatasetService.get_related_apps("dataset-1") + + assert result == ["relation-1"] + + def test_update_dataset_api_status_raises_not_found_for_missing_dataset(self): + with 
patch.object(DatasetService, "get_dataset", return_value=None): + with pytest.raises(NotFound, match="Dataset not found"): + DatasetService.update_dataset_api_status("dataset-1", True) + + def test_update_dataset_api_status_requires_current_user_id(self): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock(enable_api=False) + + with ( + patch.object(DatasetService, "get_dataset", return_value=dataset), + patch("services.dataset_service.current_user", SimpleNamespace(id=None)), + ): + with pytest.raises(ValueError, match="Current user or current user id not found"): + DatasetService.update_dataset_api_status(dataset.id, True) + + def test_update_dataset_api_status_updates_fields_and_commits(self): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock(enable_api=False) + now = object() + + with ( + patch.object(DatasetService, "get_dataset", return_value=dataset), + patch("services.dataset_service.current_user", SimpleNamespace(id="user-1")), + patch("services.dataset_service.naive_utc_now", return_value=now), + patch("services.dataset_service.db") as mock_db, + ): + DatasetService.update_dataset_api_status(dataset.id, True) + + assert dataset.enable_api is True + assert dataset.updated_by == "user-1" + assert dataset.updated_at is now + mock_db.session.commit.assert_called_once() + + def test_get_dataset_auto_disable_logs_returns_empty_when_billing_is_disabled(self): + class FakeAccount: + pass + + current_user = FakeAccount() + current_user.current_tenant_id = "tenant-1" + + features = SimpleNamespace( + billing=SimpleNamespace(enabled=False, subscription=SimpleNamespace(plan=CloudPlan.PROFESSIONAL)) + ) + + with ( + patch("services.dataset_service.Account", FakeAccount), + patch("services.dataset_service.current_user", current_user), + patch("services.dataset_service.FeatureService.get_features", return_value=features), + patch("services.dataset_service.db") as mock_db, + ): + result = DatasetService.get_dataset_auto_disable_logs("dataset-1") + + 
assert result == {"document_ids": [], "count": 0} + mock_db.session.scalars.assert_not_called() + + def test_get_dataset_auto_disable_logs_returns_recent_document_ids(self): + class FakeAccount: + pass + + current_user = FakeAccount() + current_user.current_tenant_id = "tenant-1" + logs = [SimpleNamespace(document_id="doc-1"), SimpleNamespace(document_id="doc-2")] + features = SimpleNamespace( + billing=SimpleNamespace(enabled=True, subscription=SimpleNamespace(plan=CloudPlan.PROFESSIONAL)) + ) + + with ( + patch("services.dataset_service.Account", FakeAccount), + patch("services.dataset_service.current_user", current_user), + patch("services.dataset_service.FeatureService.get_features", return_value=features), + patch("services.dataset_service.db") as mock_db, + ): + mock_db.session.scalars.return_value.all.return_value = logs + + result = DatasetService.get_dataset_auto_disable_logs("dataset-1") + + assert result == {"document_ids": ["doc-1", "doc-2"], "count": 2} + + +class TestDatasetServiceDocumentIndexing: + """Unit tests for pause/recover/retry orchestration without SQL assertions.""" + + @pytest.fixture + def mock_document_service_dependencies(self): + with ( + patch("services.dataset_service.redis_client") as mock_redis, + patch("services.dataset_service.db.session") as mock_db_session, + patch("services.dataset_service.current_user") as mock_current_user, + ): + mock_current_user.id = "user-123" + yield { + "redis_client": mock_redis, + "db_session": mock_db_session, + "current_user": mock_current_user, + } + + def test_pause_document_success(self, mock_document_service_dependencies): + document = DatasetServiceUnitDataFactory.create_document_mock(indexing_status="indexing") + + DocumentService.pause_document(document) + + assert document.is_paused is True + assert document.paused_by == "user-123" + mock_document_service_dependencies["db_session"].add.assert_called_once_with(document) + 
mock_document_service_dependencies["db_session"].commit.assert_called_once() + mock_document_service_dependencies["redis_client"].setnx.assert_called_once_with( + f"document_{document.id}_is_paused", + "True", + ) + + def test_pause_document_invalid_status_error(self, mock_document_service_dependencies): + document = DatasetServiceUnitDataFactory.create_document_mock(indexing_status="completed") + + with pytest.raises(DocumentIndexingError): + DocumentService.pause_document(document) + + def test_recover_document_success(self, mock_document_service_dependencies): + document = DatasetServiceUnitDataFactory.create_document_mock(indexing_status="indexing", is_paused=True) + + with patch("services.dataset_service.recover_document_indexing_task") as recover_task: + DocumentService.recover_document(document) + + assert document.is_paused is False + assert document.paused_by is None + assert document.paused_at is None + mock_document_service_dependencies["db_session"].add.assert_called_once_with(document) + mock_document_service_dependencies["db_session"].commit.assert_called_once() + mock_document_service_dependencies["redis_client"].delete.assert_called_once_with( + f"document_{document.id}_is_paused" + ) + recover_task.delay.assert_called_once_with(document.dataset_id, document.id) + + def test_retry_document_indexing_success(self, mock_document_service_dependencies): + dataset_id = "dataset-123" + documents = [ + DatasetServiceUnitDataFactory.create_document_mock(document_id="doc-1", indexing_status="error"), + DatasetServiceUnitDataFactory.create_document_mock(document_id="doc-2", indexing_status="error"), + ] + mock_document_service_dependencies["redis_client"].get.return_value = None + + with patch("services.dataset_service.retry_document_indexing_task") as retry_task: + DocumentService.retry_document(dataset_id, documents) + + assert all(document.indexing_status == "waiting" for document in documents) + assert 
mock_document_service_dependencies["db_session"].add.call_count == 2 + assert mock_document_service_dependencies["db_session"].commit.call_count == 2 + assert mock_document_service_dependencies["redis_client"].setex.call_count == 2 + retry_task.delay.assert_called_once_with(dataset_id, ["doc-1", "doc-2"], "user-123") + + +class TestDatasetCollectionBindingService: + """Unit tests for dataset collection binding lookups and creation.""" + + def test_get_dataset_collection_binding_returns_existing_binding(self): + binding = SimpleNamespace(id="binding-1") + + with patch("services.dataset_service.db") as mock_db: + mock_db.session.query.return_value.where.return_value.order_by.return_value.first.return_value = binding + + result = DatasetCollectionBindingService.get_dataset_collection_binding("provider", "model") + + assert result is binding + mock_db.session.add.assert_not_called() + + def test_get_dataset_collection_binding_creates_binding_when_missing(self): + created_binding = SimpleNamespace(id="binding-2") + + with ( + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.DatasetCollectionBinding", return_value=created_binding) as binding_cls, + patch.object(Dataset, "gen_collection_name_by_id", return_value="generated-collection"), + ): + mock_db.session.query.return_value.where.return_value.order_by.return_value.first.return_value = None + + result = DatasetCollectionBindingService.get_dataset_collection_binding("provider", "model", "dataset") + + assert result is created_binding + binding_cls.assert_called_once_with( + provider_name="provider", + model_name="model", + collection_name="generated-collection", + type="dataset", + ) + mock_db.session.add.assert_called_once_with(created_binding) + mock_db.session.commit.assert_called_once() + + def test_get_dataset_collection_binding_by_id_and_type_raises_when_missing(self): + with patch("services.dataset_service.db") as mock_db: + 
mock_db.session.query.return_value.where.return_value.order_by.return_value.first.return_value = None + + with pytest.raises(ValueError, match="Dataset collection binding not found"): + DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type("binding-1") + + def test_get_dataset_collection_binding_by_id_and_type_returns_binding(self): + binding = SimpleNamespace(id="binding-1") + + with patch("services.dataset_service.db") as mock_db: + mock_db.session.query.return_value.where.return_value.order_by.return_value.first.return_value = binding + + result = DatasetCollectionBindingService.get_dataset_collection_binding_by_id_and_type("binding-1") + + assert result is binding + + +class TestDatasetPermissionService: + """Unit tests for dataset partial-member management helpers.""" + + def test_get_dataset_partial_member_list_returns_scalar_results(self): + with patch("services.dataset_service.db") as mock_db: + mock_db.session.scalars.return_value.all.return_value = ["user-1", "user-2"] + + result = DatasetPermissionService.get_dataset_partial_member_list("dataset-1") + + assert result == ["user-1", "user-2"] + + def test_update_partial_member_list_replaces_permissions_and_commits(self): + with patch("services.dataset_service.db") as mock_db: + DatasetPermissionService.update_partial_member_list( + "tenant-1", + "dataset-1", + [{"user_id": "user-1"}, {"user_id": "user-2"}], + ) + + mock_db.session.query.return_value.where.return_value.delete.assert_called_once() + mock_db.session.add_all.assert_called_once() + mock_db.session.commit.assert_called_once() + + def test_update_partial_member_list_rolls_back_on_exception(self): + with patch("services.dataset_service.db") as mock_db: + mock_db.session.add_all.side_effect = RuntimeError("boom") + + with pytest.raises(RuntimeError, match="boom"): + DatasetPermissionService.update_partial_member_list( + "tenant-1", + "dataset-1", + [{"user_id": "user-1"}], + ) + + 
mock_db.session.rollback.assert_called_once() + + def test_check_permission_requires_dataset_editor(self): + user = SimpleNamespace(is_dataset_editor=False, is_dataset_operator=False) + dataset = DatasetServiceUnitDataFactory.create_dataset_mock() + + with pytest.raises(NoPermissionError, match="does not have permission"): + DatasetPermissionService.check_permission(user, dataset, "all_team", []) + + def test_check_permission_prevents_dataset_operator_from_changing_permission_mode(self): + user = SimpleNamespace(is_dataset_editor=True, is_dataset_operator=True) + dataset = DatasetServiceUnitDataFactory.create_dataset_mock(permission="all_team") + + with pytest.raises(NoPermissionError, match="cannot change the dataset permissions"): + DatasetPermissionService.check_permission(user, dataset, "only_me", []) + + def test_check_permission_requires_partial_member_list_for_partial_members_mode(self): + user = SimpleNamespace(is_dataset_editor=True, is_dataset_operator=True) + dataset = DatasetServiceUnitDataFactory.create_dataset_mock(permission="partial_members") + + with pytest.raises(ValueError, match="Partial member list is required"): + DatasetPermissionService.check_permission(user, dataset, "partial_members", []) + + def test_check_permission_rejects_dataset_operator_member_list_changes(self): + user = SimpleNamespace(is_dataset_editor=True, is_dataset_operator=True) + dataset = DatasetServiceUnitDataFactory.create_dataset_mock( + dataset_id="dataset-1", permission="partial_members" + ) + + with patch.object(DatasetPermissionService, "get_dataset_partial_member_list", return_value=["user-1"]): + with pytest.raises(ValueError, match="cannot change the dataset permissions"): + DatasetPermissionService.check_permission( + user, + dataset, + "partial_members", + [{"user_id": "user-2"}], + ) + + def test_check_permission_allows_dataset_operator_when_member_list_is_unchanged(self): + user = SimpleNamespace(is_dataset_editor=True, is_dataset_operator=True) + dataset = 
DatasetServiceUnitDataFactory.create_dataset_mock( + dataset_id="dataset-1", permission="partial_members" + ) + + with patch.object(DatasetPermissionService, "get_dataset_partial_member_list", return_value=["user-1"]): + DatasetPermissionService.check_permission( + user, + dataset, + "partial_members", + [{"user_id": "user-1"}], + ) + + def test_clear_partial_member_list_deletes_permissions_and_commits(self): + with patch("services.dataset_service.db") as mock_db: + DatasetPermissionService.clear_partial_member_list("dataset-1") + + mock_db.session.query.return_value.where.return_value.delete.assert_called_once() + mock_db.session.commit.assert_called_once() + + def test_clear_partial_member_list_rolls_back_on_exception(self): + with patch("services.dataset_service.db") as mock_db: + mock_db.session.query.return_value.where.return_value.delete.side_effect = RuntimeError("boom") + + with pytest.raises(RuntimeError, match="boom"): + DatasetPermissionService.clear_partial_member_list("dataset-1") + + mock_db.session.rollback.assert_called_once() diff --git a/api/tests/unit_tests/services/test_dataset_service_document.py b/api/tests/unit_tests/services/test_dataset_service_document.py new file mode 100644 index 0000000000..c8036487ab --- /dev/null +++ b/api/tests/unit_tests/services/test_dataset_service_document.py @@ -0,0 +1,2078 @@ +"""Unit tests for DocumentService behaviors in dataset_service.""" + +from .dataset_service_test_helpers import ( + Account, + BuiltInField, + CloudPlan, + DatasetProcessRule, + DatasetService, + DatasetServiceUnitDataFactory, + DataSource, + DocumentIndexingError, + DocumentService, + FileInfo, + FileNotExistsError, + Forbidden, + IndexStructureType, + InfoList, + KnowledgeConfig, + MagicMock, + NoPermissionError, + NotFound, + NotionIcon, + NotionInfo, + NotionPage, + PreProcessingRule, + ProcessRule, + RerankingModel, + RetrievalMethod, + RetrievalModel, + Rule, + Segmentation, + SimpleNamespace, + WebsiteInfo, + _make_dataset, + 
_make_document, + _make_features, + _make_lock_context, + _make_session_context, + _make_upload_knowledge_config, + create_autospec, + json, + patch, + pytest, +) + + +class TestDocumentServiceDisplayStatus: + """Unit tests for DocumentService display-status helpers.""" + + @pytest.mark.parametrize( + ("raw_status", "expected"), + [ + ("enabled", "available"), + ("AVAILABLE", "available"), + ("paused", "paused"), + ("unknown", None), + (None, None), + ], + ) + def test_normalize_display_status(self, raw_status, expected): + assert DocumentService.normalize_display_status(raw_status) == expected + + def test_build_display_status_filters_returns_empty_tuple_for_unknown_status(self): + assert DocumentService.build_display_status_filters("missing") == () + + def test_apply_display_status_filter_returns_original_query_for_unknown_status(self): + query = MagicMock() + + result = DocumentService.apply_display_status_filter(query, "missing") + + assert result is query + query.where.assert_not_called() + + def test_apply_display_status_filter_applies_where_for_known_status(self): + query = MagicMock() + filtered_query = MagicMock() + query.where.return_value = filtered_query + + result = DocumentService.apply_display_status_filter(query, "enabled") + + assert result is filtered_query + query.where.assert_called_once() + + +class TestDocumentServiceQueryAndDownloadHelpers: + """Unit tests for DocumentService query helpers and download flows.""" + + def test_get_document_returns_none_when_document_id_is_missing(self): + with patch("services.dataset_service.db") as mock_db: + result = DocumentService.get_document("dataset-1", None) + + assert result is None + mock_db.session.query.assert_not_called() + + def test_get_document_queries_by_dataset_and_document_id(self): + document = DatasetServiceUnitDataFactory.create_document_mock() + + with patch("services.dataset_service.db") as mock_db: + mock_db.session.query.return_value.where.return_value.first.return_value = document + + 
result = DocumentService.get_document("dataset-1", "doc-1") + + assert result is document + + def test_get_documents_by_ids_returns_empty_for_empty_input(self): + with patch("services.dataset_service.db") as mock_db: + result = DocumentService.get_documents_by_ids("dataset-1", []) + + assert result == [] + mock_db.session.scalars.assert_not_called() + + def test_get_documents_by_ids_uses_single_batch_query(self): + document = DatasetServiceUnitDataFactory.create_document_mock() + + with patch("services.dataset_service.db") as mock_db: + mock_db.session.scalars.return_value.all.return_value = [document] + + result = DocumentService.get_documents_by_ids("dataset-1", ["doc-1"]) + + assert result == [document] + mock_db.session.scalars.assert_called_once() + + def test_update_documents_need_summary_returns_zero_for_empty_input(self): + with patch("services.dataset_service.session_factory") as session_factory_mock: + result = DocumentService.update_documents_need_summary("dataset-1", []) + + assert result == 0 + session_factory_mock.create_session.assert_not_called() + + def test_update_documents_need_summary_updates_matching_documents_and_commits(self): + session = MagicMock() + session.query.return_value.filter.return_value.update.return_value = 2 + + with patch("services.dataset_service.session_factory") as session_factory_mock: + session_factory_mock.create_session.return_value = _make_session_context(session) + + result = DocumentService.update_documents_need_summary( + "dataset-1", + ["doc-1", "doc-2"], + need_summary=False, + ) + + assert result == 2 + session.commit.assert_called_once() + + def test_get_document_download_url_uses_upload_file_lookup_and_signed_url_helper(self): + upload_file = DatasetServiceUnitDataFactory.create_upload_file_mock(file_id="file-1") + document = DatasetServiceUnitDataFactory.create_document_mock() + + with ( + patch.object(DocumentService, "_get_upload_file_for_upload_file_document", return_value=upload_file), + 
patch("services.dataset_service.file_helpers.get_signed_file_url", return_value="signed-url") as get_url, + ): + result = DocumentService.get_document_download_url(document) + + assert result == "signed-url" + get_url.assert_called_once_with(upload_file_id="file-1", as_attachment=True) + + def test_get_upload_file_id_for_upload_file_document_rejects_invalid_source_type(self): + document = DatasetServiceUnitDataFactory.create_document_mock(data_source_type="not-upload-file") + + with pytest.raises(NotFound, match="invalid source"): + DocumentService._get_upload_file_id_for_upload_file_document( + document, + invalid_source_message="invalid source", + missing_file_message="missing file", + ) + + def test_get_upload_file_id_for_upload_file_document_rejects_missing_upload_file_id(self): + document = DatasetServiceUnitDataFactory.create_document_mock(data_source_info_dict={}) + + with pytest.raises(NotFound, match="missing file"): + DocumentService._get_upload_file_id_for_upload_file_document( + document, + invalid_source_message="invalid source", + missing_file_message="missing file", + ) + + def test_get_upload_file_id_for_upload_file_document_returns_string_id(self): + document = DatasetServiceUnitDataFactory.create_document_mock(data_source_info_dict={"upload_file_id": 99}) + + result = DocumentService._get_upload_file_id_for_upload_file_document( + document, + invalid_source_message="invalid source", + missing_file_message="missing file", + ) + + assert result == "99" + + def test_get_upload_file_for_upload_file_document_raises_when_file_service_returns_nothing(self): + document = DatasetServiceUnitDataFactory.create_document_mock( + tenant_id="tenant-1", + data_source_info_dict={"upload_file_id": "file-1"}, + ) + + with patch("services.dataset_service.FileService.get_upload_files_by_ids", return_value={}): + with pytest.raises(NotFound, match="Uploaded file not found"): + DocumentService._get_upload_file_for_upload_file_document(document) + + def 
test_get_upload_file_for_upload_file_document_returns_upload_file(self): + document = DatasetServiceUnitDataFactory.create_document_mock( + tenant_id="tenant-1", + data_source_info_dict={"upload_file_id": "file-1"}, + ) + upload_file = DatasetServiceUnitDataFactory.create_upload_file_mock(file_id="file-1") + + with patch( + "services.dataset_service.FileService.get_upload_files_by_ids", return_value={"file-1": upload_file} + ): + result = DocumentService._get_upload_file_for_upload_file_document(document) + + assert result is upload_file + + def test_enrich_documents_with_summary_index_status_skips_lookup_when_summary_is_disabled(self): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock(summary_index_setting={"enable": False}) + documents = [ + DatasetServiceUnitDataFactory.create_document_mock(document_id="doc-1", need_summary=True), + DatasetServiceUnitDataFactory.create_document_mock(document_id="doc-2", need_summary=False), + ] + + DocumentService.enrich_documents_with_summary_index_status(documents, dataset, tenant_id="tenant-1") + + assert documents[0].summary_index_status is None + assert documents[1].summary_index_status is None + + def test_enrich_documents_with_summary_index_status_applies_summary_status_map(self): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock( + dataset_id="dataset-1", + summary_index_setting={"enable": True}, + ) + documents = [ + DatasetServiceUnitDataFactory.create_document_mock(document_id="doc-1", need_summary=True), + DatasetServiceUnitDataFactory.create_document_mock(document_id="doc-2", need_summary=True), + DatasetServiceUnitDataFactory.create_document_mock(document_id="doc-3", need_summary=False), + ] + + with patch( + "services.summary_index_service.SummaryIndexService.get_documents_summary_index_status", + return_value={"doc-1": "completed", "doc-2": None}, + ) as get_status_map: + DocumentService.enrich_documents_with_summary_index_status(documents, dataset, tenant_id="tenant-1") + + 
get_status_map.assert_called_once_with( + document_ids=["doc-1", "doc-2"], + dataset_id="dataset-1", + tenant_id="tenant-1", + ) + assert documents[0].summary_index_status == "completed" + assert documents[1].summary_index_status is None + assert documents[2].summary_index_status is None + + def test_generate_document_batch_download_zip_filename_uses_zip_extension(self): + fake_uuid = SimpleNamespace(hex="archive-id") + + with patch("services.dataset_service.uuid.uuid4", return_value=fake_uuid): + result = DocumentService._generate_document_batch_download_zip_filename() + + assert result == "archive-id.zip" + + def test_get_upload_files_by_document_id_for_zip_download_raises_for_missing_documents(self): + with patch.object(DocumentService, "get_documents_by_ids", return_value=[]): + with pytest.raises(NotFound, match="Document not found"): + DocumentService._get_upload_files_by_document_id_for_zip_download( + dataset_id="dataset-1", + document_ids=["doc-1"], + tenant_id="tenant-1", + ) + + def test_get_upload_files_by_document_id_for_zip_download_rejects_cross_tenant_access(self): + document = DatasetServiceUnitDataFactory.create_document_mock( + document_id="doc-1", + tenant_id="tenant-other", + data_source_info_dict={"upload_file_id": "file-1"}, + ) + + with patch.object(DocumentService, "get_documents_by_ids", return_value=[document]): + with pytest.raises(Forbidden, match="No permission"): + DocumentService._get_upload_files_by_document_id_for_zip_download( + dataset_id="dataset-1", + document_ids=["doc-1"], + tenant_id="tenant-1", + ) + + def test_get_upload_files_by_document_id_for_zip_download_rejects_missing_upload_files(self): + document = DatasetServiceUnitDataFactory.create_document_mock( + document_id="doc-1", + tenant_id="tenant-1", + data_source_info_dict={"upload_file_id": "file-1"}, + ) + + with ( + patch.object(DocumentService, "get_documents_by_ids", return_value=[document]), + patch("services.dataset_service.FileService.get_upload_files_by_ids", 
return_value={}), + ): + with pytest.raises(NotFound, match="Only uploaded-file documents can be downloaded as ZIP"): + DocumentService._get_upload_files_by_document_id_for_zip_download( + dataset_id="dataset-1", + document_ids=["doc-1"], + tenant_id="tenant-1", + ) + + def test_get_upload_files_by_document_id_for_zip_download_returns_document_keyed_mapping(self): + document_a = DatasetServiceUnitDataFactory.create_document_mock( + document_id="doc-1", + tenant_id="tenant-1", + data_source_info_dict={"upload_file_id": "file-1"}, + ) + document_b = DatasetServiceUnitDataFactory.create_document_mock( + document_id="doc-2", + tenant_id="tenant-1", + data_source_info_dict={"upload_file_id": "file-2"}, + ) + upload_file_a = DatasetServiceUnitDataFactory.create_upload_file_mock(file_id="file-1") + upload_file_b = DatasetServiceUnitDataFactory.create_upload_file_mock(file_id="file-2") + + with ( + patch.object(DocumentService, "get_documents_by_ids", return_value=[document_a, document_b]), + patch( + "services.dataset_service.FileService.get_upload_files_by_ids", + return_value={"file-1": upload_file_a, "file-2": upload_file_b}, + ), + ): + result = DocumentService._get_upload_files_by_document_id_for_zip_download( + dataset_id="dataset-1", + document_ids=["doc-1", "doc-2"], + tenant_id="tenant-1", + ) + + assert result == {"doc-1": upload_file_a, "doc-2": upload_file_b} + + def test_prepare_document_batch_download_zip_raises_not_found_for_missing_dataset(self): + user = DatasetServiceUnitDataFactory.create_user_mock() + + with patch.object(DatasetService, "get_dataset", return_value=None): + with pytest.raises(NotFound, match="Dataset not found"): + DocumentService.prepare_document_batch_download_zip( + dataset_id="dataset-1", + document_ids=["doc-1"], + tenant_id="tenant-1", + current_user=user, + ) + + def test_prepare_document_batch_download_zip_translates_permission_error_to_forbidden(self): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock() + user = 
DatasetServiceUnitDataFactory.create_user_mock() + + with ( + patch.object(DatasetService, "get_dataset", return_value=dataset), + patch.object(DatasetService, "check_dataset_permission", side_effect=NoPermissionError("blocked")), + ): + with pytest.raises(Forbidden, match="blocked"): + DocumentService.prepare_document_batch_download_zip( + dataset_id=dataset.id, + document_ids=["doc-1"], + tenant_id="tenant-1", + current_user=user, + ) + + def test_prepare_document_batch_download_zip_returns_upload_files_in_requested_order(self): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock() + user = DatasetServiceUnitDataFactory.create_user_mock() + upload_file_a = DatasetServiceUnitDataFactory.create_upload_file_mock(file_id="file-a") + upload_file_b = DatasetServiceUnitDataFactory.create_upload_file_mock(file_id="file-b") + + with ( + patch.object(DatasetService, "get_dataset", return_value=dataset), + patch.object(DatasetService, "check_dataset_permission"), + patch.object( + DocumentService, + "_get_upload_files_by_document_id_for_zip_download", + return_value={"doc-1": upload_file_a, "doc-2": upload_file_b}, + ), + patch.object(DocumentService, "_generate_document_batch_download_zip_filename", return_value="archive.zip"), + ): + upload_files, download_name = DocumentService.prepare_document_batch_download_zip( + dataset_id=dataset.id, + document_ids=["doc-2", "doc-1"], + tenant_id="tenant-1", + current_user=user, + ) + + assert upload_files == [upload_file_b, upload_file_a] + assert download_name == "archive.zip" + + def test_get_document_by_dataset_id_returns_enabled_documents(self): + document = DatasetServiceUnitDataFactory.create_document_mock(enabled=True) + + with patch("services.dataset_service.db") as mock_db: + mock_db.session.scalars.return_value.all.return_value = [document] + + result = DocumentService.get_document_by_dataset_id("dataset-1") + + assert result == [document] + + def 
test_get_working_documents_by_dataset_id_returns_scalars_result(self): + document = DatasetServiceUnitDataFactory.create_document_mock(indexing_status="completed", archived=False) + + with patch("services.dataset_service.db") as mock_db: + mock_db.session.scalars.return_value.all.return_value = [document] + + result = DocumentService.get_working_documents_by_dataset_id("dataset-1") + + assert result == [document] + + def test_get_error_documents_by_dataset_id_returns_scalars_result(self): + document = DatasetServiceUnitDataFactory.create_document_mock(indexing_status="error") + + with patch("services.dataset_service.db") as mock_db: + mock_db.session.scalars.return_value.all.return_value = [document] + + result = DocumentService.get_error_documents_by_dataset_id("dataset-1") + + assert result == [document] + + def test_get_batch_documents_filters_by_current_user_tenant(self): + class FakeAccount: + pass + + current_user = FakeAccount() + current_user.current_tenant_id = "tenant-1" + document = DatasetServiceUnitDataFactory.create_document_mock() + + with ( + patch("services.dataset_service.Account", FakeAccount), + patch("services.dataset_service.current_user", current_user), + patch("services.dataset_service.db") as mock_db, + ): + mock_db.session.scalars.return_value.all.return_value = [document] + + result = DocumentService.get_batch_documents("dataset-1", "batch-1") + + assert result == [document] + + def test_get_document_file_detail_returns_one_or_none(self): + upload_file = DatasetServiceUnitDataFactory.create_upload_file_mock() + + with patch("services.dataset_service.db") as mock_db: + mock_db.session.query.return_value.where.return_value.one_or_none.return_value = upload_file + + result = DocumentService.get_document_file_detail(upload_file.id) + + assert result is upload_file + + +class TestDocumentServiceMutations: + """Unit tests for DocumentService mutation and orchestration helpers.""" + + @pytest.fixture + def rename_account_context(self): + class 
FakeAccount: + pass + + current_user = FakeAccount() + current_user.id = "user-123" + current_user.current_tenant_id = "tenant-123" + + with ( + patch("services.dataset_service.Account", FakeAccount), + patch("services.dataset_service.current_user", current_user), + ): + yield current_user + + @pytest.mark.parametrize(("archived", "expected"), [(True, True), (False, False)]) + def test_check_archived_returns_boolean_status(self, archived, expected): + document = DatasetServiceUnitDataFactory.create_document_mock(archived=archived) + + assert DocumentService.check_archived(document) is expected + + def test_delete_document_emits_signal_and_commits(self): + document = DatasetServiceUnitDataFactory.create_document_mock( + data_source_type="upload_file", + data_source_info='{"upload_file_id": "file-1"}', + data_source_info_dict={"upload_file_id": "file-1"}, + ) + + with ( + patch("services.dataset_service.document_was_deleted.send") as send_deleted_signal, + patch("services.dataset_service.db") as mock_db, + ): + DocumentService.delete_document(document) + + send_deleted_signal.assert_called_once_with( + document.id, + dataset_id=document.dataset_id, + doc_form=document.doc_form, + file_id="file-1", + ) + mock_db.session.delete.assert_called_once_with(document) + mock_db.session.commit.assert_called_once() + + def test_delete_documents_ignores_empty_input(self): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock() + + with patch("services.dataset_service.db") as mock_db: + DocumentService.delete_documents(dataset, []) + + mock_db.session.scalars.assert_not_called() + + def test_delete_documents_deletes_rows_and_dispatches_cleanup_task(self): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock(doc_form="text_model") + document_a = DatasetServiceUnitDataFactory.create_document_mock( + document_id="doc-1", + data_source_type="upload_file", + data_source_info_dict={"upload_file_id": "file-1"}, + ) + document_b = 
DatasetServiceUnitDataFactory.create_document_mock( + document_id="doc-2", + data_source_type="upload_file", + data_source_info_dict={"upload_file_id": "file-2"}, + ) + + with ( + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.batch_clean_document_task") as clean_task, + ): + mock_db.session.scalars.return_value.all.return_value = [document_a, document_b] + + DocumentService.delete_documents(dataset, ["doc-1", "doc-2"]) + + assert mock_db.session.delete.call_count == 2 + mock_db.session.commit.assert_called_once() + clean_task.delay.assert_called_once_with(["doc-1", "doc-2"], dataset.id, dataset.doc_form, ["file-1", "file-2"]) + + def test_rename_document_raises_when_dataset_is_missing(self, rename_account_context): + with patch.object(DatasetService, "get_dataset", return_value=None): + with pytest.raises(ValueError, match="Dataset not found"): + DocumentService.rename_document("dataset-1", "doc-1", "New Name") + + def test_rename_document_raises_when_document_is_missing(self, rename_account_context): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock() + + with ( + patch.object(DatasetService, "get_dataset", return_value=dataset), + patch.object(DocumentService, "get_document", return_value=None), + ): + with pytest.raises(ValueError, match="Document not found"): + DocumentService.rename_document(dataset.id, "doc-1", "New Name") + + def test_rename_document_rejects_cross_tenant_access(self, rename_account_context): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock() + document = DatasetServiceUnitDataFactory.create_document_mock(tenant_id="tenant-other") + + with ( + patch.object(DatasetService, "get_dataset", return_value=dataset), + patch.object(DocumentService, "get_document", return_value=document), + ): + with pytest.raises(ValueError, match="No permission"): + DocumentService.rename_document(dataset.id, document.id, "New Name") + + def 
test_rename_document_updates_document_metadata_and_upload_file_name(self, rename_account_context): + dataset = DatasetServiceUnitDataFactory.create_dataset_mock( + built_in_field_enabled=True, + tenant_id="tenant-1", + ) + document = DatasetServiceUnitDataFactory.create_document_mock( + tenant_id="tenant-1", + doc_metadata={"title": "Old"}, + data_source_info_dict={"upload_file_id": "file-1"}, + ) + rename_account_context.current_tenant_id = "tenant-1" + + with ( + patch.object(DatasetService, "get_dataset", return_value=dataset), + patch.object(DocumentService, "get_document", return_value=document), + patch("services.dataset_service.db") as mock_db, + ): + result = DocumentService.rename_document(dataset.id, document.id, "New Name") + + assert result is document + assert document.name == "New Name" + assert document.doc_metadata[BuiltInField.document_name] == "New Name" + mock_db.session.add.assert_called_once_with(document) + mock_db.session.query.return_value.where.return_value.update.assert_called_once() + mock_db.session.commit.assert_called_once() + + def test_recover_document_raises_when_document_is_not_paused(self): + document = DatasetServiceUnitDataFactory.create_document_mock(is_paused=False) + + with pytest.raises(DocumentIndexingError): + DocumentService.recover_document(document) + + def test_retry_document_raises_when_retry_flag_is_already_set(self): + document = DatasetServiceUnitDataFactory.create_document_mock(document_id="doc-1") + + with patch("services.dataset_service.redis_client") as mock_redis: + mock_redis.get.return_value = "1" + + with pytest.raises(ValueError, match="being retried"): + DocumentService.retry_document("dataset-1", [document]) + + def test_sync_website_document_raises_when_sync_flag_exists(self): + document = DatasetServiceUnitDataFactory.create_document_mock(document_id="doc-1") + + with patch("services.dataset_service.redis_client") as mock_redis: + mock_redis.get.return_value = "1" + + with pytest.raises(ValueError, 
match="being synced"): + DocumentService.sync_website_document("dataset-1", document) + + def test_sync_website_document_updates_status_sets_cache_and_dispatches_task(self): + document = DatasetServiceUnitDataFactory.create_document_mock( + document_id="doc-1", + data_source_info_dict={"mode": "crawl"}, + ) + document.data_source_info = "{}" + + with ( + patch("services.dataset_service.redis_client") as mock_redis, + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.sync_website_document_indexing_task") as sync_task, + ): + mock_redis.get.return_value = None + + DocumentService.sync_website_document("dataset-1", document) + + assert document.indexing_status == "waiting" + assert '"mode": "scrape"' in document.data_source_info + mock_db.session.add.assert_called_once_with(document) + mock_db.session.commit.assert_called_once() + mock_redis.setex.assert_called_once_with("document_doc-1_is_sync", 600, 1) + sync_task.delay.assert_called_once_with("dataset-1", "doc-1") + + def test_get_documents_position_returns_next_position_when_documents_exist(self): + document = DatasetServiceUnitDataFactory.create_document_mock(position=7) + + with patch("services.dataset_service.db") as mock_db: + mock_db.session.query.return_value.filter_by.return_value.order_by.return_value.first.return_value = ( + document + ) + + result = DocumentService.get_documents_position("dataset-1") + + assert result == 8 + + def test_get_documents_position_defaults_to_one_when_dataset_is_empty(self): + with patch("services.dataset_service.db") as mock_db: + mock_db.session.query.return_value.filter_by.return_value.order_by.return_value.first.return_value = None + + result = DocumentService.get_documents_position("dataset-1") + + assert result == 1 + + +class TestDocumentServiceSaveDocumentWithoutDatasetId: + """Unit tests for dataset creation around save_document_without_dataset_id.""" + + @pytest.fixture + def account_context(self): + account = create_autospec(Account, 
instance=True) + account.id = "user-1" + account.current_tenant_id = "tenant-1" + + with patch("services.dataset_service.current_user", account): + yield account + + def test_save_document_without_dataset_id_creates_high_quality_dataset_with_default_retrieval_model( + self, account_context + ): + knowledge_config = KnowledgeConfig( + indexing_technique="high_quality", + data_source=DataSource( + info_list=InfoList( + data_source_type="upload_file", + file_info_list=FileInfo(file_ids=["file-1"]), + ) + ), + embedding_model="embedding-model", + embedding_model_provider="provider", + summary_index_setting={"enable": True}, + is_multimodal=True, + ) + created_dataset = SimpleNamespace( + id="dataset-1", + tenant_id="tenant-1", + name="", + description=None, + ) + first_document = SimpleNamespace(name="VeryLongDocumentNameForDataset.txt") + + with ( + patch("services.dataset_service.FeatureService.get_features", return_value=_make_features(enabled=False)), + patch( + "services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding", + return_value=SimpleNamespace(id="binding-1"), + ), + patch( + "services.dataset_service.Dataset", + side_effect=lambda **kwargs: created_dataset.__dict__.update(kwargs) or created_dataset, + ) as dataset_cls, + patch.object( + DocumentService, "save_document_with_dataset_id", return_value=([first_document], "batch-1") + ) as save_document, + patch("services.dataset_service.db") as mock_db, + ): + dataset, documents, batch = DocumentService.save_document_without_dataset_id( + tenant_id="tenant-1", + knowledge_config=knowledge_config, + account=account_context, + ) + + assert dataset is created_dataset + assert documents == [first_document] + assert batch == "batch-1" + assert created_dataset.collection_binding_id == "binding-1" + assert created_dataset.retrieval_model["search_method"] == RetrievalMethod.SEMANTIC_SEARCH + assert created_dataset.retrieval_model["top_k"] == 4 + assert 
created_dataset.summary_index_setting == {"enable": True} + assert created_dataset.is_multimodal is True + assert created_dataset.name == first_document.name[:18] + "..." + assert ( + created_dataset.description + == "useful for when you want to answer queries about the VeryLongDocumentNameForDataset.txt" + ) + dataset_cls.assert_called_once() + save_document.assert_called_once_with(created_dataset, knowledge_config, account_context) + assert mock_db.session.commit.call_count == 1 + + def test_save_document_without_dataset_id_uses_provided_retrieval_model(self, account_context): + retrieval_model = RetrievalModel( + search_method=RetrievalMethod.SEMANTIC_SEARCH, + reranking_enable=True, + reranking_model=RerankingModel( + reranking_provider_name="rerank-provider", + reranking_model_name="rerank-model", + ), + top_k=9, + score_threshold_enabled=True, + score_threshold=0.6, + ) + knowledge_config = KnowledgeConfig( + indexing_technique="economy", + data_source=DataSource( + info_list=InfoList( + data_source_type="upload_file", + file_info_list=FileInfo(file_ids=["file-1"]), + ) + ), + retrieval_model=retrieval_model, + ) + created_dataset = SimpleNamespace(id="dataset-1", tenant_id="tenant-1", name="", description=None) + + with ( + patch("services.dataset_service.FeatureService.get_features", return_value=_make_features(enabled=False)), + patch( + "services.dataset_service.Dataset", + side_effect=lambda **kwargs: created_dataset.__dict__.update(kwargs) or created_dataset, + ), + patch.object( + DocumentService, + "save_document_with_dataset_id", + return_value=([SimpleNamespace(name="Doc")], "batch-1"), + ), + patch("services.dataset_service.db"), + ): + DocumentService.save_document_without_dataset_id("tenant-1", knowledge_config, account_context) + + assert created_dataset.retrieval_model == retrieval_model.model_dump() + assert created_dataset.collection_binding_id is None + + def test_save_document_without_dataset_id_rejects_sandbox_batch_upload(self, 
account_context): + knowledge_config = KnowledgeConfig( + indexing_technique="economy", + data_source=DataSource( + info_list=InfoList( + data_source_type="upload_file", + file_info_list=FileInfo(file_ids=["file-1", "file-2"]), + ) + ), + ) + + with ( + patch( + "services.dataset_service.FeatureService.get_features", + return_value=_make_features(enabled=True, plan=CloudPlan.SANDBOX), + ), + patch.object(DocumentService, "check_documents_upload_quota") as check_quota, + ): + with pytest.raises(ValueError, match="does not support batch upload"): + DocumentService.save_document_without_dataset_id("tenant-1", knowledge_config, account_context) + + check_quota.assert_not_called() + + +class TestDocumentServiceUpdateDocumentWithDatasetId: + """Unit tests for the document-update orchestration path.""" + + @pytest.fixture + def account_context(self): + account = create_autospec(Account, instance=True) + account.id = "user-1" + account.current_tenant_id = "tenant-1" + + with patch("services.dataset_service.current_user", account): + yield account + + def test_update_document_with_dataset_id_raises_when_document_is_missing(self, account_context): + dataset = SimpleNamespace(id="dataset-1", tenant_id="tenant-1") + document_data = KnowledgeConfig( + original_document_id="doc-1", + indexing_technique="economy", + data_source=DataSource( + info_list=InfoList( + data_source_type="upload_file", + file_info_list=FileInfo(file_ids=["file-1"]), + ) + ), + ) + + with ( + patch.object(DocumentService, "get_document", return_value=None), + patch.object(DatasetService, "check_dataset_model_setting") as check_model_setting, + ): + with pytest.raises(NotFound, match="Document not found"): + DocumentService.update_document_with_dataset_id(dataset, document_data, account_context) + + check_model_setting.assert_called_once_with(dataset) + + def test_update_document_with_dataset_id_rejects_non_available_documents(self, account_context): + dataset = SimpleNamespace(id="dataset-1", 
tenant_id="tenant-1") + document = SimpleNamespace(display_status="indexing") + document_data = KnowledgeConfig( + original_document_id="doc-1", + indexing_technique="economy", + data_source=DataSource( + info_list=InfoList( + data_source_type="upload_file", + file_info_list=FileInfo(file_ids=["file-1"]), + ) + ), + ) + + with ( + patch.object(DocumentService, "get_document", return_value=document), + patch.object(DatasetService, "check_dataset_model_setting"), + ): + with pytest.raises(ValueError, match="Document is not available"): + DocumentService.update_document_with_dataset_id(dataset, document_data, account_context) + + def test_update_document_with_dataset_id_upload_file_process_rule_and_name_override(self, account_context): + dataset = SimpleNamespace(id="dataset-1", tenant_id="tenant-1") + document = _make_document() + document.dataset_process_rule_id = "old-rule" + document_data = KnowledgeConfig( + original_document_id="doc-1", + indexing_technique="economy", + data_source=DataSource( + info_list=InfoList( + data_source_type="upload_file", + file_info_list=FileInfo(file_ids=["file-1"]), + ) + ), + process_rule=ProcessRule( + mode="custom", + rules=Rule( + pre_processing_rules=[PreProcessingRule(id="remove_stopwords", enabled=True)], + segmentation=Segmentation(separator="\n", max_tokens=128), + ), + ), + name="Renamed document", + doc_form=IndexStructureType.QA_INDEX, + ) + created_process_rule = SimpleNamespace(id="rule-2") + + with ( + patch.object(DocumentService, "get_document", return_value=document), + patch.object(DatasetService, "check_dataset_model_setting"), + patch("services.dataset_service.DatasetProcessRule", return_value=created_process_rule), + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.naive_utc_now", return_value="now"), + patch("services.dataset_service.document_indexing_update_task") as update_task, + ): + upload_query = MagicMock() + upload_query.where.return_value.first.return_value = 
SimpleNamespace(id="file-1", name="upload.txt") + segment_query = MagicMock() + segment_query.filter_by.return_value.update.return_value = 3 + mock_db.session.query.side_effect = [upload_query, segment_query] + + result = DocumentService.update_document_with_dataset_id(dataset, document_data, account_context) + + assert result is document + assert document.dataset_process_rule_id == "rule-2" + assert document.data_source_type == "upload_file" + assert document.data_source_info == '{"upload_file_id": "file-1"}' + assert document.name == "Renamed document" + assert document.indexing_status == "waiting" + assert document.completed_at is None + assert document.processing_started_at is None + assert document.parsing_completed_at is None + assert document.cleaning_completed_at is None + assert document.splitting_completed_at is None + assert document.updated_at == "now" + assert document.created_from == "web" + assert document.doc_form == IndexStructureType.QA_INDEX + assert mock_db.session.commit.call_count == 3 + segment_query.filter_by.return_value.update.assert_called_once() + update_task.delay.assert_called_once_with(document.dataset_id, document.id) + + def test_update_document_with_dataset_id_notion_import_requires_binding(self, account_context): + dataset = SimpleNamespace(id="dataset-1", tenant_id="tenant-1") + document = SimpleNamespace(display_status="available", id="doc-1", dataset_id="dataset-1") + document_data = KnowledgeConfig( + original_document_id="doc-1", + indexing_technique="economy", + data_source=DataSource( + info_list=InfoList( + data_source_type="notion_import", + notion_info_list=[ + NotionInfo( + credential_id="credential-1", + workspace_id="workspace-1", + pages=[NotionPage(page_id="page-1", page_name="Page 1", page_icon=None, type="page")], + ) + ], + ) + ), + ) + + with ( + patch.object(DocumentService, "get_document", return_value=document), + patch.object(DatasetService, "check_dataset_model_setting"), + 
patch("services.dataset_service.db") as mock_db, + ): + binding_query = MagicMock() + binding_query.where.return_value.first.return_value = None + mock_db.session.query.return_value = binding_query + + with pytest.raises(ValueError, match="Data source binding not found"): + DocumentService.update_document_with_dataset_id(dataset, document_data, account_context) + + def test_update_document_with_dataset_id_website_crawl_updates_segments_and_dispatches_task(self, account_context): + dataset = SimpleNamespace(id="dataset-1", tenant_id="tenant-1") + document = _make_document() + document_data = KnowledgeConfig( + original_document_id="doc-1", + indexing_technique="economy", + data_source=DataSource( + info_list=InfoList( + data_source_type="website_crawl", + website_info_list=WebsiteInfo( + provider="firecrawl", + job_id="job-1", + urls=["https://example.com"], + only_main_content=False, + ), + ) + ), + doc_form=IndexStructureType.PARENT_CHILD_INDEX, + ) + + with ( + patch.object(DocumentService, "get_document", return_value=document), + patch.object(DatasetService, "check_dataset_model_setting"), + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.naive_utc_now", return_value="now"), + patch("services.dataset_service.document_indexing_update_task") as update_task, + ): + segment_query = MagicMock() + segment_query.filter_by.return_value.update.return_value = 2 + mock_db.session.query.return_value = segment_query + + result = DocumentService.update_document_with_dataset_id(dataset, document_data, account_context) + + assert result is document + assert document.data_source_type == "website_crawl" + assert document.data_source_info == ( + '{"url": "https://example.com", "provider": "firecrawl", "job_id": "job-1", ' + '"only_main_content": false, "mode": "crawl"}' + ) + assert document.name == "" + assert document.doc_form == IndexStructureType.PARENT_CHILD_INDEX + segment_query.filter_by.return_value.update.assert_called_once() + 
update_task.delay.assert_called_once_with("dataset-1", "doc-1") + + +class TestDocumentServiceCreateValidation: + """Unit tests for document creation validation helpers.""" + + def test_document_create_args_validate_requires_data_source_or_process_rule(self): + knowledge_config = SimpleNamespace(data_source=None, process_rule=None) + + with pytest.raises(ValueError, match="Data source or Process rule is required"): + DocumentService.document_create_args_validate(knowledge_config) + + def test_document_create_args_validate_delegates_to_sub_validators(self): + knowledge_config = SimpleNamespace(data_source=object(), process_rule=object()) + + with ( + patch.object(DocumentService, "data_source_args_validate") as validate_data_source, + patch.object(DocumentService, "process_rule_args_validate") as validate_process_rule, + ): + DocumentService.document_create_args_validate(knowledge_config) + + validate_data_source.assert_called_once_with(knowledge_config) + validate_process_rule.assert_called_once_with(knowledge_config) + + def test_data_source_args_validate_rejects_invalid_type(self): + knowledge_config = SimpleNamespace( + data_source=SimpleNamespace( + info_list=SimpleNamespace( + data_source_type="bad-source", + file_info_list=None, + notion_info_list=None, + website_info_list=None, + ) + ) + ) + + with pytest.raises(ValueError, match="Data source type is invalid"): + DocumentService.data_source_args_validate(knowledge_config) + + @pytest.mark.parametrize( + ("data_source_type", "field_name", "message"), + [ + ("upload_file", "file_info_list", "File source info is required"), + ("notion_import", "notion_info_list", "Notion source info is required"), + ("website_crawl", "website_info_list", "Website source info is required"), + ], + ) + def test_data_source_args_validate_requires_source_specific_info(self, data_source_type, field_name, message): + info_list = SimpleNamespace( + data_source_type=data_source_type, + file_info_list=object(), + 
notion_info_list=object(), + website_info_list=object(), + ) + setattr(info_list, field_name, None) + knowledge_config = SimpleNamespace(data_source=SimpleNamespace(info_list=info_list)) + + with pytest.raises(ValueError, match=message): + DocumentService.data_source_args_validate(knowledge_config) + + def test_process_rule_args_validate_clears_rules_for_automatic_mode(self): + knowledge_config = KnowledgeConfig( + indexing_technique="economy", + data_source=DataSource( + info_list=InfoList( + data_source_type="upload_file", + file_info_list=FileInfo(file_ids=["file-1"]), + ) + ), + process_rule=ProcessRule( + mode="automatic", + rules=Rule( + pre_processing_rules=[PreProcessingRule(id="remove_stopwords", enabled=True)], + segmentation=Segmentation(separator="\n", max_tokens=128), + ), + ), + ) + + DocumentService.process_rule_args_validate(knowledge_config) + + assert knowledge_config.process_rule is not None + assert knowledge_config.process_rule.rules is None + + def test_process_rule_args_validate_deduplicates_rules_and_skips_max_tokens_for_full_doc_hierarchical(self): + knowledge_config = KnowledgeConfig( + indexing_technique="economy", + data_source=DataSource( + info_list=InfoList( + data_source_type="upload_file", + file_info_list=FileInfo(file_ids=["file-1"]), + ) + ), + process_rule=ProcessRule( + mode="hierarchical", + rules=Rule( + pre_processing_rules=[ + PreProcessingRule(id="remove_stopwords", enabled=True), + PreProcessingRule(id="remove_stopwords", enabled=False), + ], + segmentation=Segmentation(separator="\n", max_tokens=0), + parent_mode="full-doc", + ), + ), + ) + + DocumentService.process_rule_args_validate(knowledge_config) + + assert knowledge_config.process_rule is not None + assert knowledge_config.process_rule.rules is not None + assert len(knowledge_config.process_rule.rules.pre_processing_rules) == 1 + assert knowledge_config.process_rule.rules.pre_processing_rules[0].enabled is False + + +class 
TestDocumentServiceSaveDocumentWithDatasetId: + """Unit tests for non-SQL validation branches in save_document_with_dataset_id.""" + + @pytest.fixture + def account_context(self): + account = create_autospec(Account, instance=True) + account.id = "user-1" + account.current_tenant_id = "tenant-1" + + with ( + patch("services.dataset_service.current_user", account), + patch.object(DatasetService, "check_doc_form"), + ): + yield account + + def test_save_document_with_dataset_id_requires_file_info_for_upload_source(self, account_context): + dataset = _make_dataset() + knowledge_config = _make_upload_knowledge_config(file_ids=None) + + with patch("services.dataset_service.FeatureService.get_features", return_value=_make_features(enabled=True)): + with pytest.raises(ValueError, match="File source info is required"): + DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account_context) + + def test_save_document_with_dataset_id_blocks_batch_upload_for_sandbox_plan(self, account_context): + dataset = _make_dataset() + knowledge_config = _make_upload_knowledge_config(file_ids=["file-1", "file-2"]) + + with ( + patch( + "services.dataset_service.FeatureService.get_features", + return_value=_make_features(enabled=True, plan=CloudPlan.SANDBOX), + ), + patch.object(DocumentService, "check_documents_upload_quota") as check_quota, + ): + with pytest.raises(ValueError, match="does not support batch upload"): + DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account_context) + + check_quota.assert_not_called() + + def test_save_document_with_dataset_id_enforces_batch_upload_limit(self, account_context): + dataset = _make_dataset() + knowledge_config = _make_upload_knowledge_config(file_ids=["file-1", "file-2"]) + + with ( + patch("services.dataset_service.FeatureService.get_features", return_value=_make_features(enabled=True)), + patch("services.dataset_service.dify_config.BATCH_UPLOAD_LIMIT", 1), + patch.object(DocumentService, 
"check_documents_upload_quota") as check_quota, + ): + with pytest.raises(ValueError, match="batch upload limit of 1"): + DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account_context) + + check_quota.assert_not_called() + + def test_save_document_with_dataset_id_updates_existing_document_and_data_source_type(self, account_context): + dataset = _make_dataset(data_source_type=None) + knowledge_config = _make_upload_knowledge_config(original_document_id="doc-1", file_ids=["file-1"]) + updated_document = _make_document(document_id="doc-1", batch="batch-existing") + + with ( + patch("services.dataset_service.FeatureService.get_features", return_value=_make_features(enabled=False)), + patch.object( + DocumentService, "update_document_with_dataset_id", return_value=updated_document + ) as update_document, + ): + documents, batch = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account_context) + + assert dataset.data_source_type == "upload_file" + assert documents == [updated_document] + assert batch == "batch-existing" + update_document.assert_called_once_with(dataset, knowledge_config, account_context) + + def test_save_document_with_dataset_id_requires_data_source_for_new_documents(self, account_context): + dataset = _make_dataset() + knowledge_config = _make_upload_knowledge_config(data_source=None) + + with patch("services.dataset_service.FeatureService.get_features", return_value=_make_features(enabled=False)): + with pytest.raises(ValueError, match="Data source is required when creating new documents"): + DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account_context) + + def test_save_document_with_dataset_id_requires_existing_process_rule_for_custom_mode(self, account_context): + dataset = _make_dataset(latest_process_rule=None) + knowledge_config = _make_upload_knowledge_config( + file_ids=["file-1"], + process_rule=ProcessRule(mode="custom"), + ) + + with 
patch("services.dataset_service.FeatureService.get_features", return_value=_make_features(enabled=False)): + with pytest.raises(ValueError, match="No process rule found"): + DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account_context) + + def test_save_document_with_dataset_id_rejects_invalid_indexing_technique(self, account_context): + dataset = _make_dataset(indexing_technique=None) + knowledge_config = SimpleNamespace( + doc_form=IndexStructureType.PARAGRAPH_INDEX, + original_document_id=None, + data_source=None, + indexing_technique="broken-technique", + ) + + with patch("services.dataset_service.FeatureService.get_features", return_value=_make_features(enabled=False)): + with pytest.raises(ValueError, match="Indexing technique is invalid"): + DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account_context) + + def test_save_document_with_dataset_id_returns_empty_for_invalid_process_rule_mode(self, account_context): + dataset = _make_dataset() + knowledge_config = _make_upload_knowledge_config(file_ids=["file-1"]) + knowledge_config.process_rule = SimpleNamespace(mode="unsupported-mode", rules=None) + + with patch("services.dataset_service.FeatureService.get_features", return_value=_make_features(enabled=False)): + documents, batch = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account_context) + + assert documents == [] + assert batch == "" + + def test_save_document_with_dataset_id_upload_file_creates_and_reindexes_documents(self, account_context): + dataset = _make_dataset() + dataset_process_rule = SimpleNamespace(id="rule-1") + knowledge_config = _make_upload_knowledge_config(file_ids=["file-1", "file-2"]) + duplicate_document = _make_document(document_id="doc-duplicate", name="existing.txt") + created_document = _make_document(document_id="doc-created", name="new.txt") + upload_file_a = SimpleNamespace(id="file-1", name="existing.txt") + upload_file_b = 
SimpleNamespace(id="file-2", name="new.txt") + + with ( + patch("services.dataset_service.FeatureService.get_features", return_value=_make_features(enabled=False)), + patch("services.dataset_service.redis_client") as mock_redis, + patch("services.dataset_service.db") as mock_db, + patch.object(DocumentService, "get_documents_position", return_value=4), + patch.object(DocumentService, "build_document", return_value=created_document) as build_document, + patch("services.dataset_service.DocumentIndexingTaskProxy") as document_proxy_cls, + patch("services.dataset_service.DuplicateDocumentIndexingTaskProxy") as duplicate_proxy_cls, + patch("services.dataset_service.naive_utc_now", return_value="now"), + patch("services.dataset_service.time.strftime", return_value="20260101010101"), + patch("services.dataset_service.secrets.randbelow", return_value=23), + ): + mock_redis.lock.return_value = _make_lock_context() + upload_query = MagicMock() + upload_query.where.return_value.all.return_value = [upload_file_a, upload_file_b] + existing_documents_query = MagicMock() + existing_documents_query.where.return_value.all.return_value = [duplicate_document] + mock_db.session.query.side_effect = [upload_query, existing_documents_query] + + documents, batch = DocumentService.save_document_with_dataset_id( + dataset, + knowledge_config, + account_context, + dataset_process_rule=dataset_process_rule, + ) + + assert documents == [duplicate_document, created_document] + assert batch == "20260101010101100023" + assert duplicate_document.dataset_process_rule_id == "rule-1" + assert duplicate_document.updated_at == "now" + assert duplicate_document.batch == batch + assert duplicate_document.indexing_status == "waiting" + build_document.assert_called_once_with( + dataset, + "rule-1", + "upload_file", + IndexStructureType.PARAGRAPH_INDEX, + "English", + {"upload_file_id": "file-2"}, + "web", + 4, + account_context, + "new.txt", + batch, + ) + 
document_proxy_cls.assert_called_once_with(dataset.tenant_id, dataset.id, ["doc-created"]) + document_proxy_cls.return_value.delay.assert_called_once() + duplicate_proxy_cls.assert_called_once_with(dataset.tenant_id, dataset.id, ["doc-duplicate"]) + duplicate_proxy_cls.return_value.delay.assert_called_once() + + def test_save_document_with_dataset_id_notion_import_truncates_names_and_cleans_removed_pages( + self, account_context + ): + dataset = _make_dataset() + dataset_process_rule = SimpleNamespace(id="rule-1") + notion_page_name = "a" * 300 + knowledge_config = KnowledgeConfig( + indexing_technique="economy", + data_source=DataSource( + info_list=InfoList( + data_source_type="notion_import", + notion_info_list=[ + NotionInfo( + credential_id="credential-1", + workspace_id="workspace-1", + pages=[ + NotionPage(page_id="page-keep", page_name="Keep page", type="page"), + NotionPage( + page_id="page-new", + page_name=notion_page_name, + page_icon=NotionIcon(type="emoji", emoji="page"), + type="page", + ), + ], + ) + ], + ) + ), + doc_form=IndexStructureType.PARAGRAPH_INDEX, + doc_language="English", + ) + existing_keep = _make_document(document_id="doc-keep") + existing_keep.data_source_info = json.dumps({"notion_page_id": "page-keep"}) + existing_remove = _make_document(document_id="doc-remove") + existing_remove.data_source_info = json.dumps({"notion_page_id": "page-remove"}) + created_document = _make_document(document_id="doc-new") + + with ( + patch("services.dataset_service.FeatureService.get_features", return_value=_make_features(enabled=False)), + patch("services.dataset_service.redis_client") as mock_redis, + patch("services.dataset_service.db") as mock_db, + patch.object(DocumentService, "get_documents_position", return_value=1), + patch.object(DocumentService, "build_document", return_value=created_document) as build_document, + patch("services.dataset_service.clean_notion_document_task") as clean_task, + 
patch("services.dataset_service.DocumentIndexingTaskProxy") as document_proxy_cls, + ): + mock_redis.lock.return_value = _make_lock_context() + notion_documents_query = MagicMock() + notion_documents_query.filter_by.return_value.all.return_value = [existing_keep, existing_remove] + mock_db.session.query.return_value = notion_documents_query + + documents, _ = DocumentService.save_document_with_dataset_id( + dataset, + knowledge_config, + account_context, + dataset_process_rule=dataset_process_rule, + ) + + assert created_document in documents + assert len(build_document.call_args.args[9]) == 255 + clean_task.delay.assert_called_once_with(["doc-remove"], dataset.id) + document_proxy_cls.assert_called_once_with(dataset.tenant_id, dataset.id, ["doc-new"]) + document_proxy_cls.return_value.delay.assert_called_once() + + def test_save_document_with_dataset_id_website_crawl_truncates_long_urls(self, account_context): + dataset = _make_dataset() + dataset_process_rule = SimpleNamespace(id="rule-1") + long_url = "https://example.com/" + ("a" * 260) + short_url = "https://example.com/short" + knowledge_config = KnowledgeConfig( + indexing_technique="economy", + data_source=DataSource( + info_list=InfoList( + data_source_type="website_crawl", + website_info_list=WebsiteInfo( + provider="firecrawl", + job_id="job-1", + urls=[long_url, short_url], + only_main_content=True, + ), + ) + ), + doc_form=IndexStructureType.PARAGRAPH_INDEX, + doc_language="English", + ) + first_document = _make_document(document_id="doc-1") + second_document = _make_document(document_id="doc-2") + + with ( + patch("services.dataset_service.FeatureService.get_features", return_value=_make_features(enabled=False)), + patch("services.dataset_service.redis_client") as mock_redis, + patch("services.dataset_service.db") as mock_db, + patch.object(DocumentService, "get_documents_position", return_value=2), + patch.object( + DocumentService, + "build_document", + side_effect=[first_document, second_document], 
+ ) as build_document, + patch("services.dataset_service.DocumentIndexingTaskProxy") as document_proxy_cls, + ): + mock_redis.lock.return_value = _make_lock_context() + + documents, _ = DocumentService.save_document_with_dataset_id( + dataset, + knowledge_config, + account_context, + dataset_process_rule=dataset_process_rule, + ) + + assert documents == [first_document, second_document] + assert build_document.call_args_list[0].args[9] == long_url[:200] + "..." + assert build_document.call_args_list[1].args[9] == short_url + document_proxy_cls.assert_called_once_with(dataset.tenant_id, dataset.id, ["doc-1", "doc-2"]) + document_proxy_cls.return_value.delay.assert_called_once() + + +class TestDocumentServiceBatchUpdateStatus: + """Unit tests for batch_update_document_status orchestration and helper branches.""" + + def test_prepare_disable_update_requires_completed_document(self): + document = _make_document(indexing_status="waiting") + document.completed_at = None + + with pytest.raises(DocumentIndexingError, match="is not completed"): + DocumentService._prepare_disable_update(document, user=SimpleNamespace(id="user-1"), now="now") + + def test_prepare_archive_update_sets_async_task_for_enabled_document(self): + document = _make_document(enabled=True, archived=False) + + result = DocumentService._prepare_archive_update(document, user=SimpleNamespace(id="user-1"), now="now") + + assert result is not None + assert result["updates"]["archived"] is True + assert result["set_cache"] is True + assert result["async_task"]["args"] == [document.id] + + def test_prepare_unarchive_update_sets_async_task_for_enabled_document(self): + document = _make_document(enabled=True, archived=True) + + result = DocumentService._prepare_unarchive_update(document, now="now") + + assert result is not None + assert result["updates"]["archived"] is False + assert result["set_cache"] is True + assert result["async_task"]["args"] == [document.id] + + def 
test_batch_update_document_status_rejects_indexing_documents(self): + dataset = _make_dataset() + document = _make_document(name="Busy document") + + with ( + patch.object(DocumentService, "get_document", return_value=document), + patch("services.dataset_service.redis_client") as mock_redis, + patch("services.dataset_service.db") as mock_db, + ): + mock_redis.get.return_value = "1" + + with pytest.raises(DocumentIndexingError, match="Busy document is being indexed"): + DocumentService.batch_update_document_status( + dataset, [document.id], "archive", SimpleNamespace(id="user-1") + ) + + mock_db.session.commit.assert_not_called() + + def test_batch_update_document_status_rolls_back_when_commit_fails(self): + dataset = _make_dataset() + document = _make_document(enabled=False) + + with ( + patch.object(DocumentService, "get_document", return_value=document), + patch("services.dataset_service.redis_client") as mock_redis, + patch("services.dataset_service.db") as mock_db, + ): + mock_redis.get.return_value = None + mock_db.session.commit.side_effect = RuntimeError("commit failed") + + with pytest.raises(RuntimeError, match="commit failed"): + DocumentService.batch_update_document_status( + dataset, [document.id], "enable", SimpleNamespace(id="user-1") + ) + + mock_db.session.rollback.assert_called_once() + + def test_batch_update_document_status_raises_async_task_error_after_commit(self): + dataset = _make_dataset() + document = _make_document(enabled=False) + + with ( + patch.object(DocumentService, "get_document", return_value=document), + patch("services.dataset_service.redis_client") as mock_redis, + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.add_document_to_index_task") as add_task, + ): + mock_redis.get.return_value = None + add_task.delay.side_effect = RuntimeError("task failed") + + with pytest.raises(RuntimeError, match="task failed"): + DocumentService.batch_update_document_status( + dataset, [document.id], "enable", 
SimpleNamespace(id="user-1") + ) + + mock_db.session.commit.assert_called_once() + mock_redis.setex.assert_called_once_with(f"document_{document.id}_indexing", 600, 1) + + +class TestDocumentServiceTenantAndUpdateEdges: + """Unit tests for tenant-count and update edge cases.""" + + @pytest.fixture + def account_context(self): + account = create_autospec(Account, instance=True) + account.id = "user-1" + account.current_tenant_id = "tenant-1" + + with patch("services.dataset_service.current_user", account): + yield account + + def test_get_tenant_documents_count_returns_query_count(self, account_context): + with patch("services.dataset_service.db") as mock_db: + mock_db.session.query.return_value.where.return_value.count.return_value = 12 + + result = DocumentService.get_tenant_documents_count() + + assert result == 12 + mock_db.session.query.return_value.where.return_value.count.assert_called_once() + + def test_update_document_with_dataset_id_uses_automatic_process_rule_payload(self, account_context): + dataset = SimpleNamespace(id="dataset-1", tenant_id="tenant-1") + document = _make_document() + document_data = KnowledgeConfig( + original_document_id="doc-1", + indexing_technique="economy", + data_source=DataSource( + info_list=InfoList( + data_source_type="upload_file", + file_info_list=FileInfo(file_ids=["file-1"]), + ) + ), + process_rule=ProcessRule( + mode="automatic", + rules=Rule( + pre_processing_rules=[PreProcessingRule(id="remove_stopwords", enabled=True)], + segmentation=Segmentation(separator="\n", max_tokens=128), + ), + ), + doc_form=IndexStructureType.PARAGRAPH_INDEX, + ) + created_process_rule = SimpleNamespace(id="rule-2") + + with ( + patch.object(DocumentService, "get_document", return_value=document), + patch("services.dataset_service.DatasetProcessRule") as process_rule_cls, + patch.object(DatasetService, "check_dataset_model_setting"), + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.naive_utc_now", 
return_value="now"), + patch("services.dataset_service.document_indexing_update_task") as update_task, + ): + process_rule_cls.AUTOMATIC_RULES = DatasetProcessRule.AUTOMATIC_RULES + process_rule_cls.return_value = created_process_rule + upload_query = MagicMock() + upload_query.where.return_value.first.return_value = SimpleNamespace(id="file-1", name="upload.txt") + segment_query = MagicMock() + segment_query.filter_by.return_value.update.return_value = 1 + mock_db.session.query.side_effect = [upload_query, segment_query] + + result = DocumentService.update_document_with_dataset_id(dataset, document_data, account_context) + + assert result is document + assert document.dataset_process_rule_id == "rule-2" + assert document.name == "upload.txt" + assert process_rule_cls.call_args.kwargs == { + "dataset_id": "dataset-1", + "mode": "automatic", + "rules": json.dumps(DatasetProcessRule.AUTOMATIC_RULES), + "created_by": "user-1", + } + assert mock_db.session.commit.call_count == 3 + update_task.delay.assert_called_once_with("dataset-1", "doc-1") + + def test_update_document_with_dataset_id_requires_upload_file_info(self, account_context): + dataset = SimpleNamespace(id="dataset-1", tenant_id="tenant-1") + document_data = KnowledgeConfig( + original_document_id="doc-1", + indexing_technique="economy", + data_source=DataSource(info_list=InfoList(data_source_type="upload_file")), + ) + + with ( + patch.object(DocumentService, "get_document", return_value=_make_document()), + patch.object(DatasetService, "check_dataset_model_setting"), + ): + with pytest.raises(ValueError, match="No file info list found"): + DocumentService.update_document_with_dataset_id(dataset, document_data, account_context) + + def test_update_document_with_dataset_id_raises_when_upload_file_is_missing(self, account_context): + dataset = SimpleNamespace(id="dataset-1", tenant_id="tenant-1") + document_data = KnowledgeConfig( + original_document_id="doc-1", + indexing_technique="economy", + 
data_source=DataSource( + info_list=InfoList( + data_source_type="upload_file", + file_info_list=FileInfo(file_ids=["file-1"]), + ) + ), + ) + + with ( + patch.object(DocumentService, "get_document", return_value=_make_document()), + patch.object(DatasetService, "check_dataset_model_setting"), + patch("services.dataset_service.db") as mock_db, + ): + mock_db.session.query.return_value.where.return_value.first.return_value = None + + with pytest.raises(FileNotExistsError): + DocumentService.update_document_with_dataset_id(dataset, document_data, account_context) + + def test_update_document_with_dataset_id_requires_notion_info_list(self, account_context): + dataset = SimpleNamespace(id="dataset-1", tenant_id="tenant-1") + document_data = KnowledgeConfig( + original_document_id="doc-1", + indexing_technique="economy", + data_source=DataSource(info_list=InfoList(data_source_type="notion_import")), + ) + + with ( + patch.object(DocumentService, "get_document", return_value=_make_document()), + patch.object(DatasetService, "check_dataset_model_setting"), + ): + with pytest.raises(ValueError, match="No notion info list found"): + DocumentService.update_document_with_dataset_id(dataset, document_data, account_context) + + def test_update_document_with_dataset_id_notion_import_updates_page_info(self, account_context): + dataset = SimpleNamespace(id="dataset-1", tenant_id="tenant-1") + document = _make_document() + document_data = KnowledgeConfig( + original_document_id="doc-1", + indexing_technique="economy", + data_source=DataSource( + info_list=InfoList( + data_source_type="notion_import", + notion_info_list=[ + NotionInfo( + credential_id="credential-1", + workspace_id="workspace-1", + pages=[ + NotionPage(page_id="page-1", page_name="Page 1", page_icon=None, type="page"), + NotionPage(page_id="page-2", page_name="Page 2", page_icon=None, type="database"), + ], + ) + ], + ) + ), + doc_form=IndexStructureType.PARAGRAPH_INDEX, + ) + + with ( + 
patch.object(DocumentService, "get_document", return_value=document), + patch.object(DatasetService, "check_dataset_model_setting"), + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.naive_utc_now", return_value="now"), + patch("services.dataset_service.document_indexing_update_task") as update_task, + ): + binding_query = MagicMock() + binding_query.where.return_value.first.return_value = SimpleNamespace(id="binding-1") + segment_query = MagicMock() + segment_query.filter_by.return_value.update.return_value = 1 + mock_db.session.query.side_effect = [binding_query, segment_query] + + result = DocumentService.update_document_with_dataset_id(dataset, document_data, account_context) + + assert result is document + assert document.data_source_type == "notion_import" + assert document.name == "" + assert document.data_source_info == json.dumps( + { + "credential_id": "credential-1", + "notion_workspace_id": "workspace-1", + "notion_page_id": "page-2", + "notion_page_icon": None, + "type": "database", + } + ) + update_task.delay.assert_called_once_with("dataset-1", "doc-1") + + +class TestDocumentServiceSaveWithoutDatasetBilling: + """Unit tests for batch-count and quota branches in save_document_without_dataset_id.""" + + @pytest.fixture + def account_context(self): + account = create_autospec(Account, instance=True) + account.id = "user-1" + account.current_tenant_id = "tenant-1" + + with patch("services.dataset_service.current_user", account): + yield account + + def test_save_document_without_dataset_id_counts_notion_pages_for_quota(self, account_context): + knowledge_config = KnowledgeConfig( + indexing_technique="economy", + data_source=DataSource( + info_list=InfoList( + data_source_type="notion_import", + notion_info_list=[ + NotionInfo( + credential_id="credential-1", + workspace_id="workspace-1", + pages=[ + NotionPage(page_id="page-1", page_name="Page 1", page_icon=None, type="page"), + NotionPage(page_id="page-2", 
page_name="Page 2", page_icon=None, type="page"), + ], + ), + NotionInfo( + credential_id="credential-2", + workspace_id="workspace-2", + pages=[NotionPage(page_id="page-3", page_name="Page 3", page_icon=None, type="page")], + ), + ], + ) + ), + ) + created_dataset = SimpleNamespace(id="dataset-1", tenant_id="tenant-1", name="", description=None) + features = _make_features(enabled=True) + + with ( + patch("services.dataset_service.FeatureService.get_features", return_value=features), + patch("services.dataset_service.dify_config.BATCH_UPLOAD_LIMIT", "10"), + patch.object(DocumentService, "check_documents_upload_quota") as check_quota, + patch( + "services.dataset_service.Dataset", + side_effect=lambda **kwargs: created_dataset.__dict__.update(kwargs) or created_dataset, + ), + patch.object( + DocumentService, + "save_document_with_dataset_id", + return_value=([SimpleNamespace(name="Doc")], "batch-1"), + ), + patch("services.dataset_service.db"), + ): + DocumentService.save_document_without_dataset_id("tenant-1", knowledge_config, account_context) + + check_quota.assert_called_once_with(3, features) + + def test_save_document_without_dataset_id_enforces_batch_limit_for_website_urls(self, account_context): + knowledge_config = KnowledgeConfig( + indexing_technique="economy", + data_source=DataSource( + info_list=InfoList( + data_source_type="website_crawl", + website_info_list=WebsiteInfo( + provider="firecrawl", + job_id="job-1", + urls=["https://example.com/a", "https://example.com/b"], + only_main_content=True, + ), + ) + ), + ) + + with ( + patch("services.dataset_service.FeatureService.get_features", return_value=_make_features(enabled=True)), + patch("services.dataset_service.dify_config.BATCH_UPLOAD_LIMIT", "1"), + patch.object(DocumentService, "check_documents_upload_quota") as check_quota, + ): + with pytest.raises(ValueError, match="batch upload limit of 1"): + DocumentService.save_document_without_dataset_id("tenant-1", knowledge_config, account_context) 
+ + check_quota.assert_not_called() + + +class TestDocumentServiceEstimateValidation: + """Unit tests for estimate_args_validate branches.""" + + def test_estimate_args_validate_rejects_missing_info_list(self): + with pytest.raises(ValueError, match="Data source info is required"): + DocumentService.estimate_args_validate({}) + + def test_estimate_args_validate_sets_empty_rules_for_automatic_mode(self): + args = { + "info_list": {"data_source_type": "upload_file"}, + "process_rule": {"mode": "automatic", "rules": {"ignored": True}}, + } + + DocumentService.estimate_args_validate(args) + + assert args["process_rule"]["rules"] == {} + + def test_estimate_args_validate_rejects_unknown_pre_processing_rule_id(self): + args = { + "info_list": {"data_source_type": "upload_file"}, + "process_rule": { + "mode": "custom", + "rules": { + "pre_processing_rules": [{"id": "unknown", "enabled": True}], + "segmentation": {"separator": "\n", "max_tokens": 128}, + }, + }, + } + + with pytest.raises(ValueError, match="pre_processing_rules id is invalid"): + DocumentService.estimate_args_validate(args) + + def test_estimate_args_validate_deduplicates_rules_for_custom_mode(self): + args = { + "info_list": {"data_source_type": "upload_file"}, + "process_rule": { + "mode": "custom", + "rules": { + "pre_processing_rules": [ + {"id": "remove_stopwords", "enabled": True}, + {"id": "remove_stopwords", "enabled": False}, + ], + "segmentation": {"separator": "\n", "max_tokens": 128}, + }, + }, + } + + DocumentService.estimate_args_validate(args) + + assert args["process_rule"]["rules"]["pre_processing_rules"] == [{"id": "remove_stopwords", "enabled": False}] + + def test_estimate_args_validate_requires_summary_index_provider_name(self): + args = { + "info_list": {"data_source_type": "upload_file"}, + "process_rule": { + "mode": "custom", + "rules": { + "pre_processing_rules": [{"id": "remove_stopwords", "enabled": True}], + "segmentation": {"separator": "\n", "max_tokens": 128}, + }, + 
"summary_index_setting": {"enable": True, "model_name": "summary-model"}, + }, + } + + with pytest.raises(ValueError, match="Summary index model provider name is required"): + DocumentService.estimate_args_validate(args) + + +class TestDocumentServiceSaveDocumentAdditionalBranches: + """Additional unit tests for dataset bootstrap and process-rule branches.""" + + @pytest.fixture + def account_context(self): + account = create_autospec(Account, instance=True) + account.id = "user-1" + account.current_tenant_id = "tenant-1" + + with ( + patch("services.dataset_service.current_user", account), + patch.object(DatasetService, "check_doc_form"), + ): + yield account + + def test_save_document_with_dataset_id_initializes_high_quality_dataset_from_default_embedding_model( + self, account_context + ): + dataset = _make_dataset(data_source_type=None, indexing_technique=None) + knowledge_config = _make_upload_knowledge_config(original_document_id="doc-1", file_ids=["file-1"]) + knowledge_config.indexing_technique = "high_quality" + knowledge_config.embedding_model = None + knowledge_config.embedding_model_provider = None + updated_document = _make_document(batch="batch-existing") + + with ( + patch("services.dataset_service.FeatureService.get_features", return_value=_make_features(enabled=False)), + patch("services.dataset_service.ModelManager") as model_manager_cls, + patch( + "services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding", + return_value=SimpleNamespace(id="binding-1"), + ) as get_binding, + patch.object(DocumentService, "update_document_with_dataset_id", return_value=updated_document), + ): + model_manager_cls.for_tenant.return_value.get_default_model_instance.return_value = SimpleNamespace( + model_name="default-embedding", + provider="default-provider", + ) + + documents, batch = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account_context) + + assert documents == [updated_document] + assert batch == 
"batch-existing" + assert dataset.data_source_type == "upload_file" + assert dataset.indexing_technique == "high_quality" + assert dataset.embedding_model == "default-embedding" + assert dataset.embedding_model_provider == "default-provider" + assert dataset.collection_binding_id == "binding-1" + assert dataset.retrieval_model == { + "search_method": "semantic_search", + "reranking_enable": False, + "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""}, + "top_k": 4, + "score_threshold_enabled": False, + } + get_binding.assert_called_once_with("default-provider", "default-embedding") + + def test_save_document_with_dataset_id_uses_explicit_embedding_and_retrieval_model(self, account_context): + dataset = _make_dataset(indexing_technique=None) + knowledge_config = _make_upload_knowledge_config(original_document_id="doc-1", file_ids=["file-1"]) + knowledge_config.indexing_technique = "high_quality" + knowledge_config.embedding_model = "explicit-model" + knowledge_config.embedding_model_provider = "explicit-provider" + knowledge_config.retrieval_model = RetrievalModel( + search_method="semantic_search", + reranking_enable=True, + reranking_model=RerankingModel( + reranking_provider_name="rerank-provider", + reranking_model_name="rerank-model", + ), + top_k=7, + score_threshold_enabled=True, + score_threshold=0.3, + ) + + with ( + patch("services.dataset_service.FeatureService.get_features", return_value=_make_features(enabled=False)), + patch("services.dataset_service.ModelManager") as model_manager_cls, + patch( + "services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding", + return_value=SimpleNamespace(id="binding-2"), + ) as get_binding, + patch.object(DocumentService, "update_document_with_dataset_id", return_value=_make_document()), + ): + DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account_context) + + 
model_manager_cls.for_tenant.return_value.get_default_model_instance.assert_not_called() + get_binding.assert_called_once_with("explicit-provider", "explicit-model") + assert dataset.embedding_model == "explicit-model" + assert dataset.embedding_model_provider == "explicit-provider" + assert dataset.retrieval_model == knowledge_config.retrieval_model.model_dump() + + def test_save_document_with_dataset_id_creates_custom_process_rule_for_new_upload_document(self, account_context): + dataset = _make_dataset() + knowledge_config = _make_upload_knowledge_config( + file_ids=["file-1"], + process_rule=ProcessRule( + mode="custom", + rules=Rule( + pre_processing_rules=[PreProcessingRule(id="remove_stopwords", enabled=True)], + segmentation=Segmentation(separator="\n", max_tokens=128), + ), + ), + ) + created_process_rule = SimpleNamespace(id="rule-custom") + created_document = _make_document(document_id="doc-created", name="file.txt") + + with ( + patch("services.dataset_service.FeatureService.get_features", return_value=_make_features(enabled=False)), + patch("services.dataset_service.redis_client") as mock_redis, + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.DatasetProcessRule") as process_rule_cls, + patch.object(DocumentService, "get_documents_position", return_value=3), + patch.object(DocumentService, "build_document", return_value=created_document), + patch("services.dataset_service.DocumentIndexingTaskProxy") as document_proxy_cls, + patch("services.dataset_service.time.strftime", return_value="20260101010101"), + patch("services.dataset_service.secrets.randbelow", return_value=23), + ): + mock_redis.lock.return_value = _make_lock_context() + process_rule_cls.return_value = created_process_rule + upload_query = MagicMock() + upload_query.where.return_value.all.return_value = [SimpleNamespace(id="file-1", name="file.txt")] + existing_documents_query = MagicMock() + existing_documents_query.where.return_value.all.return_value = 
[] + mock_db.session.query.side_effect = [upload_query, existing_documents_query] + + documents, batch = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account_context) + + assert documents == [created_document] + assert batch == "20260101010101100023" + assert process_rule_cls.call_args.kwargs == { + "dataset_id": "dataset-1", + "mode": "custom", + "rules": knowledge_config.process_rule.rules.model_dump_json(), + "created_by": "user-1", + } + document_proxy_cls.assert_called_once_with("tenant-1", "dataset-1", ["doc-created"]) + document_proxy_cls.return_value.delay.assert_called_once() + + def test_save_document_with_dataset_id_creates_automatic_process_rule_for_new_upload_document( + self, account_context + ): + dataset = _make_dataset() + knowledge_config = _make_upload_knowledge_config( + file_ids=["file-1"], + process_rule=ProcessRule(mode="automatic"), + ) + created_process_rule = SimpleNamespace(id="rule-auto") + created_document = _make_document(document_id="doc-created", name="file.txt") + + with ( + patch("services.dataset_service.FeatureService.get_features", return_value=_make_features(enabled=False)), + patch("services.dataset_service.redis_client") as mock_redis, + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.DatasetProcessRule") as process_rule_cls, + patch.object(DocumentService, "get_documents_position", return_value=1), + patch.object(DocumentService, "build_document", return_value=created_document), + patch("services.dataset_service.DocumentIndexingTaskProxy"), + patch("services.dataset_service.time.strftime", return_value="20260101010101"), + patch("services.dataset_service.secrets.randbelow", return_value=23), + ): + mock_redis.lock.return_value = _make_lock_context() + process_rule_cls.AUTOMATIC_RULES = DatasetProcessRule.AUTOMATIC_RULES + process_rule_cls.return_value = created_process_rule + upload_query = MagicMock() + upload_query.where.return_value.all.return_value = 
[SimpleNamespace(id="file-1", name="file.txt")] + existing_documents_query = MagicMock() + existing_documents_query.where.return_value.all.return_value = [] + mock_db.session.query.side_effect = [upload_query, existing_documents_query] + + DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account_context) + + assert process_rule_cls.call_args.kwargs == { + "dataset_id": "dataset-1", + "mode": "automatic", + "rules": json.dumps(DatasetProcessRule.AUTOMATIC_RULES), + "created_by": "user-1", + } + assert mock_db.session.flush.call_count >= 2 + + def test_save_document_with_dataset_id_creates_fallback_automatic_process_rule_when_latest_is_missing( + self, account_context + ): + dataset = _make_dataset(latest_process_rule=None) + knowledge_config = _make_upload_knowledge_config(file_ids=["file-1"], process_rule=None) + created_process_rule = SimpleNamespace(id="rule-fallback") + created_document = _make_document(document_id="doc-created", name="file.txt") + + with ( + patch("services.dataset_service.FeatureService.get_features", return_value=_make_features(enabled=False)), + patch("services.dataset_service.redis_client") as mock_redis, + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.DatasetProcessRule") as process_rule_cls, + patch.object(DocumentService, "get_documents_position", return_value=1), + patch.object(DocumentService, "build_document", return_value=created_document), + patch("services.dataset_service.DocumentIndexingTaskProxy"), + patch("services.dataset_service.time.strftime", return_value="20260101010101"), + patch("services.dataset_service.secrets.randbelow", return_value=23), + ): + mock_redis.lock.return_value = _make_lock_context() + process_rule_cls.AUTOMATIC_RULES = DatasetProcessRule.AUTOMATIC_RULES + process_rule_cls.return_value = created_process_rule + upload_query = MagicMock() + upload_query.where.return_value.all.return_value = [SimpleNamespace(id="file-1", name="file.txt")] + 
existing_documents_query = MagicMock() + existing_documents_query.where.return_value.all.return_value = [] + mock_db.session.query.side_effect = [upload_query, existing_documents_query] + + DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account_context) + + assert process_rule_cls.call_args.kwargs == { + "dataset_id": "dataset-1", + "mode": "automatic", + "rules": json.dumps(DatasetProcessRule.AUTOMATIC_RULES), + "created_by": "user-1", + } + + def test_save_document_with_dataset_id_raises_when_upload_file_lookup_is_incomplete(self, account_context): + dataset = _make_dataset() + knowledge_config = _make_upload_knowledge_config(file_ids=["file-1", "file-2"]) + + with ( + patch("services.dataset_service.FeatureService.get_features", return_value=_make_features(enabled=False)), + patch("services.dataset_service.redis_client") as mock_redis, + patch("services.dataset_service.db") as mock_db, + patch.object(DocumentService, "get_documents_position", return_value=1), + patch("services.dataset_service.time.strftime", return_value="20260101010101"), + patch("services.dataset_service.secrets.randbelow", return_value=23), + ): + mock_redis.lock.return_value = _make_lock_context() + upload_query = MagicMock() + upload_query.where.return_value.all.return_value = [SimpleNamespace(id="file-1", name="file.txt")] + mock_db.session.query.return_value = upload_query + + with pytest.raises(FileNotExistsError, match="One or more files not found"): + DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account_context) + + def test_save_document_with_dataset_id_requires_notion_info_list_for_notion_import(self, account_context): + dataset = _make_dataset() + knowledge_config = KnowledgeConfig( + indexing_technique="economy", + data_source=DataSource(info_list=InfoList(data_source_type="notion_import")), + doc_form=IndexStructureType.PARAGRAPH_INDEX, + doc_language="English", + ) + + with ( + 
patch("services.dataset_service.FeatureService.get_features", return_value=_make_features(enabled=False)), + patch("services.dataset_service.redis_client") as mock_redis, + patch.object(DocumentService, "get_documents_position", return_value=1), + ): + mock_redis.lock.return_value = _make_lock_context() + with pytest.raises(ValueError, match="No notion info list found"): + DocumentService.save_document_with_dataset_id( + dataset, + knowledge_config, + account_context, + dataset_process_rule=SimpleNamespace(id="rule-1"), + ) + + def test_save_document_with_dataset_id_requires_website_info_list_for_website_crawl(self, account_context): + dataset = _make_dataset() + knowledge_config = KnowledgeConfig( + indexing_technique="economy", + data_source=DataSource(info_list=InfoList(data_source_type="website_crawl")), + doc_form=IndexStructureType.PARAGRAPH_INDEX, + doc_language="English", + ) + + with ( + patch("services.dataset_service.FeatureService.get_features", return_value=_make_features(enabled=False)), + patch("services.dataset_service.redis_client") as mock_redis, + patch.object(DocumentService, "get_documents_position", return_value=1), + ): + mock_redis.lock.return_value = _make_lock_context() + with pytest.raises(ValueError, match="No website info list found"): + DocumentService.save_document_with_dataset_id( + dataset, + knowledge_config, + account_context, + dataset_process_rule=SimpleNamespace(id="rule-1"), + ) diff --git a/api/tests/unit_tests/services/test_dataset_service_segment.py b/api/tests/unit_tests/services/test_dataset_service_segment.py new file mode 100644 index 0000000000..2f8ae14a8e --- /dev/null +++ b/api/tests/unit_tests/services/test_dataset_service_segment.py @@ -0,0 +1,1017 @@ +"""Unit tests for SegmentService behaviors in dataset_service.""" + +from .dataset_service_test_helpers import ( + Account, + ChildChunk, + ChildChunkDeleteIndexError, + ChildChunkIndexingError, + ChildChunkUpdateArgs, + DocumentSegment, + IndexStructureType, + 
MagicMock, + ModelType, + SegmentService, + SegmentUpdateArgs, + SimpleNamespace, + _make_child_chunk, + _make_dataset, + _make_document, + _make_lock_context, + _make_segment, + create_autospec, + patch, + pytest, +) + + +class TestSegmentServiceChildChunks: + """Unit tests for child-chunk CRUD helpers.""" + + @pytest.fixture + def account_context(self): + account = create_autospec(Account, instance=True) + account.id = "user-1" + account.current_tenant_id = "tenant-1" + + with patch("services.dataset_service.current_user", account): + yield account + + def test_create_child_chunk_assigns_next_position_and_commits(self, account_context): + dataset = SimpleNamespace(id="dataset-1") + document = _make_document() + segment = _make_segment() + + with ( + patch("services.dataset_service.redis_client") as mock_redis, + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.uuid.uuid4", return_value="node-1"), + patch("services.dataset_service.helper.generate_text_hash", return_value="hash-1"), + patch("services.dataset_service.VectorService") as vector_service, + ): + mock_redis.lock.return_value = _make_lock_context() + mock_db.session.query.return_value.where.return_value.scalar.return_value = 2 + + child_chunk = SegmentService.create_child_chunk("child content", segment, document, dataset) + + assert isinstance(child_chunk, ChildChunk) + assert child_chunk.position == 3 + assert child_chunk.index_node_id == "node-1" + assert child_chunk.index_node_hash == "hash-1" + assert child_chunk.word_count == len("child content") + mock_db.session.add.assert_called_once_with(child_chunk) + vector_service.create_child_chunk_vector.assert_called_once_with(child_chunk, dataset) + mock_db.session.commit.assert_called_once() + + def test_create_child_chunk_rolls_back_and_raises_on_vector_failure(self, account_context): + dataset = SimpleNamespace(id="dataset-1") + document = _make_document() + segment = _make_segment() + + with ( + 
patch("services.dataset_service.redis_client") as mock_redis, + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.uuid.uuid4", return_value="node-1"), + patch("services.dataset_service.helper.generate_text_hash", return_value="hash-1"), + patch("services.dataset_service.VectorService") as vector_service, + ): + mock_redis.lock.return_value = _make_lock_context() + mock_db.session.query.return_value.where.return_value.scalar.return_value = None + vector_service.create_child_chunk_vector.side_effect = RuntimeError("vector failed") + + with pytest.raises(ChildChunkIndexingError, match="vector failed"): + SegmentService.create_child_chunk("child content", segment, document, dataset) + + mock_db.session.rollback.assert_called_once() + mock_db.session.commit.assert_not_called() + + def test_update_child_chunks_updates_deletes_and_creates_records(self, account_context): + dataset = SimpleNamespace(id="dataset-1") + document = _make_document() + segment = _make_segment() + existing_a = ChildChunk( + id="child-a", + tenant_id="tenant-1", + dataset_id="dataset-1", + document_id="doc-1", + segment_id="segment-1", + position=1, + content="old content", + word_count=11, + created_by="user-1", + ) + existing_b = ChildChunk( + id="child-b", + tenant_id="tenant-1", + dataset_id="dataset-1", + document_id="doc-1", + segment_id="segment-1", + position=2, + content="remove me", + word_count=9, + created_by="user-1", + ) + + with ( + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.uuid.uuid4", return_value="node-new"), + patch("services.dataset_service.helper.generate_text_hash", return_value="hash-new"), + patch("services.dataset_service.naive_utc_now", return_value="now"), + patch("services.dataset_service.VectorService") as vector_service, + ): + mock_db.session.scalars.return_value.all.return_value = [existing_a, existing_b] + + result = SegmentService.update_child_chunks( + [ + ChildChunkUpdateArgs(id="child-a", 
content="updated content"), + ChildChunkUpdateArgs(content="brand new"), + ], + segment, + document, + dataset, + ) + + assert [chunk.position for chunk in result] == [1, 3] + assert existing_a.content == "updated content" + assert existing_a.updated_by == account_context.id + assert existing_a.updated_at == "now" + mock_db.session.bulk_save_objects.assert_called_once_with([existing_a]) + mock_db.session.delete.assert_called_once_with(existing_b) + new_chunk = result[1] + assert isinstance(new_chunk, ChildChunk) + assert new_chunk.position == 3 + assert new_chunk.index_node_id == "node-new" + vector_service.update_child_chunk_vector.assert_called_once_with( + [new_chunk], [existing_a], [existing_b], dataset + ) + mock_db.session.commit.assert_called_once() + + def test_update_child_chunks_rolls_back_on_vector_failure(self, account_context): + dataset = SimpleNamespace(id="dataset-1") + document = _make_document() + segment = _make_segment() + existing_chunk = _make_child_chunk() + + with ( + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.naive_utc_now", return_value="now"), + patch("services.dataset_service.VectorService") as vector_service, + ): + mock_db.session.scalars.return_value.all.return_value = [existing_chunk] + vector_service.update_child_chunk_vector.side_effect = RuntimeError("vector failed") + + with pytest.raises(ChildChunkIndexingError, match="vector failed"): + SegmentService.update_child_chunks( + [ChildChunkUpdateArgs(id="child-a", content="updated content")], + segment, + document, + dataset, + ) + + mock_db.session.rollback.assert_called_once() + + def test_update_child_chunk_updates_vector_and_commits(self, account_context): + dataset = SimpleNamespace(id="dataset-1") + child_chunk = _make_child_chunk() + + with ( + patch("services.dataset_service.current_user", SimpleNamespace(id="user-1")), + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.naive_utc_now", 
return_value="now"), + patch("services.dataset_service.VectorService") as vector_service, + ): + result = SegmentService.update_child_chunk( + "new content", child_chunk, _make_segment(), _make_document(), dataset + ) + + assert result is child_chunk + assert child_chunk.content == "new content" + assert child_chunk.word_count == len("new content") + assert child_chunk.updated_by == "user-1" + assert child_chunk.updated_at == "now" + mock_db.session.add.assert_called_once_with(child_chunk) + vector_service.update_child_chunk_vector.assert_called_once_with([], [child_chunk], [], dataset) + mock_db.session.commit.assert_called_once() + + def test_delete_child_chunk_raises_delete_index_error_on_vector_failure(self): + dataset = SimpleNamespace(id="dataset-1") + child_chunk = _make_child_chunk() + + with ( + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.VectorService") as vector_service, + ): + vector_service.delete_child_chunk_vector.side_effect = RuntimeError("delete failed") + + with pytest.raises(ChildChunkDeleteIndexError, match="delete failed"): + SegmentService.delete_child_chunk(child_chunk, dataset) + + mock_db.session.delete.assert_called_once_with(child_chunk) + mock_db.session.rollback.assert_called_once() + + +class TestSegmentServiceQueries: + """Unit tests for child-chunk and segment query helpers.""" + + @pytest.fixture + def account_context(self): + account = create_autospec(Account, instance=True) + account.id = "user-1" + account.current_tenant_id = "tenant-1" + + with patch("services.dataset_service.current_user", account): + yield account + + def test_get_child_chunks_applies_keyword_filter_and_paginate(self, account_context): + paginated = SimpleNamespace(items=["chunk"], total=1) + + with ( + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.helper.escape_like_pattern", return_value="escaped") as escape_like, + ): + mock_db.paginate.return_value = paginated + + result = 
SegmentService.get_child_chunks( + segment_id="segment-1", + document_id="doc-1", + dataset_id="dataset-1", + page=2, + limit=10, + keyword="needle", + ) + + assert result is paginated + escape_like.assert_called_once_with("needle") + mock_db.paginate.assert_called_once() + + def test_get_child_chunk_by_id_returns_only_child_chunk_instances(self): + child_chunk = _make_child_chunk() + + with patch("services.dataset_service.db") as mock_db: + mock_db.session.query.return_value.where.return_value.first.return_value = child_chunk + result = SegmentService.get_child_chunk_by_id("child-a", "tenant-1") + + assert result is child_chunk + + with patch("services.dataset_service.db") as mock_db: + mock_db.session.query.return_value.where.return_value.first.return_value = SimpleNamespace() + result = SegmentService.get_child_chunk_by_id("child-a", "tenant-1") + + assert result is None + + def test_get_segments_uses_status_and_keyword_filters(self): + paginated = SimpleNamespace(items=["segment"], total=1) + + with ( + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.helper.escape_like_pattern", return_value="escaped") as escape_like, + ): + mock_db.paginate.return_value = paginated + + items, total = SegmentService.get_segments( + document_id="doc-1", + tenant_id="tenant-1", + status_list=["completed"], + keyword="needle", + page=1, + limit=20, + ) + + assert items == ["segment"] + assert total == 1 + escape_like.assert_called_once_with("needle") + mock_db.paginate.assert_called_once() + + def test_get_segment_by_id_returns_only_document_segment_instances(self): + segment = DocumentSegment( + id="segment-1", + tenant_id="tenant-1", + dataset_id="dataset-1", + document_id="doc-1", + position=1, + content="segment", + word_count=7, + tokens=2, + created_by="user-1", + ) + + with patch("services.dataset_service.db") as mock_db: + mock_db.session.query.return_value.where.return_value.first.return_value = segment + result = 
SegmentService.get_segment_by_id("segment-1", "tenant-1") + + assert result is segment + + with patch("services.dataset_service.db") as mock_db: + mock_db.session.query.return_value.where.return_value.first.return_value = SimpleNamespace() + result = SegmentService.get_segment_by_id("segment-1", "tenant-1") + + assert result is None + + def test_get_segments_by_document_and_dataset_returns_scalars_result(self): + segment = DocumentSegment( + id="segment-1", + tenant_id="tenant-1", + dataset_id="dataset-1", + document_id="doc-1", + position=1, + content="segment", + word_count=7, + tokens=2, + created_by="user-1", + ) + + with patch("services.dataset_service.db") as mock_db: + mock_db.session.scalars.return_value.all.return_value = [segment] + + result = SegmentService.get_segments_by_document_and_dataset( + document_id="doc-1", + dataset_id="dataset-1", + status="completed", + enabled=True, + ) + + assert result == [segment] + mock_db.session.scalars.assert_called_once() + + +class TestSegmentServiceValidation: + """Unit tests for segment-create argument validation.""" + + def test_segment_create_args_validate_requires_answer_for_qa_model(self): + document = _make_document(doc_form=IndexStructureType.QA_INDEX) + + with pytest.raises(ValueError, match="Answer is required"): + SegmentService.segment_create_args_validate({"content": "question"}, document) + + def test_segment_create_args_validate_requires_non_empty_content(self): + document = _make_document(doc_form=IndexStructureType.PARAGRAPH_INDEX) + + with pytest.raises(ValueError, match="Content is empty"): + SegmentService.segment_create_args_validate({"content": " "}, document) + + def test_segment_create_args_validate_enforces_attachment_limit(self): + document = _make_document(doc_form=IndexStructureType.PARAGRAPH_INDEX) + args = {"content": "hello", "attachment_ids": ["a-1", "a-2"]} + + with patch("services.dataset_service.dify_config.SINGLE_CHUNK_ATTACHMENT_LIMIT", 1): + with pytest.raises(ValueError, 
match="Exceeded maximum attachment limit of 1"): + SegmentService.segment_create_args_validate(args, document) + + def test_segment_create_args_validate_requires_attachment_ids_list(self): + document = _make_document(doc_form=IndexStructureType.PARAGRAPH_INDEX) + + with pytest.raises(ValueError, match="Attachment IDs is invalid"): + SegmentService.segment_create_args_validate({"content": "hello", "attachment_ids": "bad-type"}, document) + + +class TestSegmentServiceMutations: + """Unit tests for segment create, update, delete, and bulk status flows.""" + + @pytest.fixture + def account_context(self): + account = create_autospec(Account, instance=True) + account.id = "user-1" + account.current_tenant_id = "tenant-1" + + with patch("services.dataset_service.current_user", account): + yield account + + def test_create_segment_creates_bindings_and_marks_segment_error_on_vector_failure(self, account_context): + dataset = _make_dataset(indexing_technique="economy") + document = _make_document( + dataset_id=dataset.id, + tenant_id=dataset.tenant_id, + doc_form=IndexStructureType.QA_INDEX, + word_count=0, + ) + refreshed_segment = SimpleNamespace(id="segment-1") + args = { + "content": "question", + "answer": "answer", + "keywords": ["kw-1"], + "attachment_ids": ["att-1", "att-2"], + } + + with ( + patch("services.dataset_service.redis_client") as mock_redis, + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.VectorService") as vector_service, + patch("services.dataset_service.helper.generate_text_hash", return_value="hash-1"), + patch("services.dataset_service.uuid.uuid4", return_value="node-1"), + patch("services.dataset_service.naive_utc_now", return_value="now"), + ): + mock_redis.lock.return_value = _make_lock_context() + + max_position_query = MagicMock() + max_position_query.where.return_value.scalar.return_value = 2 + refresh_query = MagicMock() + refresh_query.where.return_value.first.return_value = refreshed_segment + 
mock_db.session.query.side_effect = [max_position_query, refresh_query] + + def add_side_effect(obj): + if obj.__class__.__name__ == "DocumentSegment" and getattr(obj, "id", None) is None: + obj.id = "segment-1" + + mock_db.session.add.side_effect = add_side_effect + vector_service.create_segments_vector.side_effect = RuntimeError("vector failed") + + result = SegmentService.create_segment(args=args, document=document, dataset=dataset) + + created_segment = vector_service.create_segments_vector.call_args.args[1][0] + attachment_bindings = [ + call.args[0] + for call in mock_db.session.add.call_args_list + if call.args and call.args[0].__class__.__name__ == "SegmentAttachmentBinding" + ] + + assert result is refreshed_segment + assert created_segment.position == 3 + assert created_segment.answer == "answer" + assert created_segment.word_count == len("question") + len("answer") + assert created_segment.status == "error" + assert created_segment.enabled is False + assert created_segment.error == "vector failed" + assert document.word_count == len("question") + len("answer") + assert len(attachment_bindings) == 2 + assert {binding.attachment_id for binding in attachment_bindings} == {"att-1", "att-2"} + assert mock_db.session.commit.call_count == 3 + + def test_multi_create_segment_high_quality_marks_segments_error_when_vector_creation_fails(self, account_context): + dataset = _make_dataset(indexing_technique="high_quality") + document = _make_document( + dataset_id=dataset.id, + tenant_id=dataset.tenant_id, + doc_form=IndexStructureType.QA_INDEX, + word_count=5, + ) + segments = [ + {"content": "question-1", "answer": "answer-1", "keywords": ["k1"]}, + {"content": "question-2", "answer": "answer-2"}, + ] + embedding_model = MagicMock() + embedding_model.get_text_embedding_num_tokens.side_effect = [[11], [13]] + + with ( + patch("services.dataset_service.redis_client") as mock_redis, + patch("services.dataset_service.db") as mock_db, + 
patch("services.dataset_service.ModelManager") as model_manager_cls, + patch("services.dataset_service.VectorService") as vector_service, + patch("services.dataset_service.helper.generate_text_hash", side_effect=["hash-1", "hash-2"]), + patch("services.dataset_service.uuid.uuid4", side_effect=["node-1", "node-2"]), + patch("services.dataset_service.naive_utc_now", return_value="now"), + ): + mock_redis.lock.return_value = _make_lock_context() + model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model + mock_db.session.query.return_value.where.return_value.scalar.return_value = 1 + vector_service.create_segments_vector.side_effect = RuntimeError("vector failed") + + result = SegmentService.multi_create_segment(segments, document, dataset) + + assert len(result) == 2 + assert [segment.position for segment in result] == [2, 3] + assert [segment.tokens for segment in result] == [11, 13] + assert all(segment.status == "error" for segment in result) + assert all(segment.enabled is False for segment in result) + assert all(segment.error == "vector failed" for segment in result) + assert document.word_count == 5 + sum(len(item["content"]) + len(item["answer"]) for item in segments) + vector_service.create_segments_vector.assert_called_once_with( + [["k1"], None], result, dataset, document.doc_form + ) + mock_db.session.commit.assert_called_once() + + def test_update_segment_disables_enabled_segment_and_dispatches_index_cleanup(self, account_context): + segment = _make_segment(enabled=True) + document = _make_document() + dataset = _make_dataset() + args = SegmentUpdateArgs(enabled=False) + + with ( + patch("services.dataset_service.redis_client") as mock_redis, + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.naive_utc_now", return_value="now"), + patch("services.dataset_service.disable_segment_from_index_task") as disable_task, + ): + mock_redis.get.return_value = None + + result = 
SegmentService.update_segment(args, segment, document, dataset) + + assert result is segment + assert segment.enabled is False + assert segment.disabled_at == "now" + assert segment.disabled_by == account_context.id + mock_db.session.add.assert_called_once_with(segment) + mock_db.session.commit.assert_called_once() + mock_redis.setex.assert_called_once_with(f"segment_{segment.id}_indexing", 600, 1) + disable_task.delay.assert_called_once_with(segment.id) + + def test_update_segment_rejects_updates_for_disabled_segment(self, account_context): + segment = _make_segment(enabled=False) + document = _make_document() + dataset = _make_dataset() + + with patch("services.dataset_service.redis_client") as mock_redis: + mock_redis.get.return_value = None + + with pytest.raises(ValueError, match="Can't update disabled segment"): + SegmentService.update_segment(SegmentUpdateArgs(content="new content"), segment, document, dataset) + + def test_update_segment_rejects_when_indexing_cache_exists(self, account_context): + segment = _make_segment(enabled=True) + document = _make_document() + dataset = _make_dataset() + + with patch("services.dataset_service.redis_client") as mock_redis: + mock_redis.get.return_value = "1" + + with pytest.raises(ValueError, match="Segment is indexing"): + SegmentService.update_segment(SegmentUpdateArgs(content="new content"), segment, document, dataset) + + def test_update_segment_updates_keywords_for_same_content_segment(self, account_context): + segment = _make_segment(content="same content", keywords=["old"]) + document = _make_document(doc_form=IndexStructureType.PARAGRAPH_INDEX, word_count=20) + dataset = _make_dataset() + refreshed_segment = SimpleNamespace(id=segment.id) + args = SegmentUpdateArgs(content="same content", keywords=["new"]) + + with ( + patch("services.dataset_service.redis_client") as mock_redis, + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.VectorService") as vector_service, + ): + 
mock_redis.get.return_value = None + mock_db.session.query.return_value.where.return_value.first.return_value = refreshed_segment + + result = SegmentService.update_segment(args, segment, document, dataset) + + assert result is refreshed_segment + assert segment.keywords == ["new"] + vector_service.update_segment_vector.assert_called_once_with(["new"], segment, dataset) + vector_service.update_multimodel_vector.assert_called_once_with(segment, [], dataset) + + def test_update_segment_regenerates_child_chunks_and_updates_manual_summary(self, account_context): + segment = _make_segment(content="same content", word_count=len("same content")) + document = _make_document( + doc_form=IndexStructureType.PARENT_CHILD_INDEX, + word_count=20, + ) + dataset = _make_dataset(indexing_technique="high_quality") + refreshed_segment = SimpleNamespace(id=segment.id) + processing_rule = SimpleNamespace(id=document.dataset_process_rule_id) + existing_summary = SimpleNamespace(summary_content="old summary") + embedding_model_instance = object() + args = SegmentUpdateArgs( + content="same content", + regenerate_child_chunks=True, + summary="new summary", + ) + + with ( + patch("services.dataset_service.redis_client") as mock_redis, + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.ModelManager") as model_manager_cls, + patch("services.dataset_service.VectorService") as vector_service, + patch("services.summary_index_service.SummaryIndexService.update_summary_for_segment") as update_summary, + ): + mock_redis.get.return_value = None + model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model_instance + + processing_rule_query = MagicMock() + processing_rule_query.where.return_value.first.return_value = processing_rule + summary_query = MagicMock() + summary_query.where.return_value.first.return_value = existing_summary + refreshed_query = MagicMock() + refreshed_query.where.return_value.first.return_value = 
refreshed_segment + mock_db.session.query.side_effect = [processing_rule_query, summary_query, refreshed_query] + + result = SegmentService.update_segment(args, segment, document, dataset) + + assert result is refreshed_segment + vector_service.generate_child_chunks.assert_called_once_with( + segment, + document, + dataset, + embedding_model_instance, + processing_rule, + True, + ) + update_summary.assert_called_once_with(segment, dataset, "new summary") + vector_service.update_multimodel_vector.assert_called_once_with(segment, [], dataset) + + def test_update_segment_auto_regenerates_summary_after_content_change(self, account_context): + segment = _make_segment(content="old", word_count=3) + document = _make_document(doc_form=IndexStructureType.PARAGRAPH_INDEX, word_count=10) + dataset = _make_dataset(indexing_technique="high_quality") + dataset.summary_index_setting = {"enable": True} + refreshed_segment = SimpleNamespace(id=segment.id) + existing_summary = SimpleNamespace(summary_content="old summary") + embedding_model = MagicMock() + embedding_model.get_text_embedding_num_tokens.return_value = [9] + args = SegmentUpdateArgs(content="new content", keywords=["kw-1"]) + + with ( + patch("services.dataset_service.redis_client") as mock_redis, + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.ModelManager") as model_manager_cls, + patch("services.dataset_service.VectorService") as vector_service, + patch("services.dataset_service.helper.generate_text_hash", return_value="hash-1"), + patch("services.dataset_service.naive_utc_now", return_value="now"), + patch( + "services.summary_index_service.SummaryIndexService.generate_and_vectorize_summary" + ) as generate_summary, + ): + mock_redis.get.return_value = None + model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model + + summary_query = MagicMock() + summary_query.where.return_value.first.return_value = existing_summary + refreshed_query = 
MagicMock() + refreshed_query.where.return_value.first.return_value = refreshed_segment + mock_db.session.query.side_effect = [summary_query, refreshed_query] + + result = SegmentService.update_segment(args, segment, document, dataset) + + assert result is refreshed_segment + assert segment.content == "new content" + assert segment.index_node_hash == "hash-1" + assert segment.tokens == 9 + assert document.word_count == 18 + vector_service.update_segment_vector.assert_called_once_with(["kw-1"], segment, dataset) + generate_summary.assert_called_once_with(segment, dataset, {"enable": True}) + vector_service.update_multimodel_vector.assert_called_once_with(segment, [], dataset) + + def test_update_segment_regenerates_summary_when_manual_summary_is_unchanged(self, account_context): + segment = _make_segment(content="old", word_count=3) + document = _make_document(doc_form=IndexStructureType.PARAGRAPH_INDEX, word_count=10) + dataset = _make_dataset(indexing_technique="high_quality") + dataset.summary_index_setting = {"enable": True} + refreshed_segment = SimpleNamespace(id=segment.id) + existing_summary = SimpleNamespace(summary_content="same summary") + embedding_model = MagicMock() + embedding_model.get_text_embedding_num_tokens.return_value = [7] + args = SegmentUpdateArgs(content="new text", summary="same summary") + + with ( + patch("services.dataset_service.redis_client") as mock_redis, + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.ModelManager") as model_manager_cls, + patch("services.dataset_service.VectorService") as vector_service, + patch("services.dataset_service.helper.generate_text_hash", return_value="hash-2"), + patch("services.dataset_service.naive_utc_now", return_value="now"), + patch( + "services.summary_index_service.SummaryIndexService.generate_and_vectorize_summary" + ) as generate_summary, + patch("services.summary_index_service.SummaryIndexService.update_summary_for_segment") as update_summary, + ): + 
mock_redis.get.return_value = None + model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model + + summary_query = MagicMock() + summary_query.where.return_value.first.return_value = existing_summary + refreshed_query = MagicMock() + refreshed_query.where.return_value.first.return_value = refreshed_segment + mock_db.session.query.side_effect = [summary_query, refreshed_query] + + result = SegmentService.update_segment(args, segment, document, dataset) + + assert result is refreshed_segment + generate_summary.assert_called_once_with(segment, dataset, {"enable": True}) + update_summary.assert_not_called() + vector_service.update_multimodel_vector.assert_called_once_with(segment, [], dataset) + + def test_delete_segment_removes_index_and_updates_document_word_count(self): + segment = _make_segment(word_count=4, index_node_id="parent-node") + document = _make_document(word_count=10) + dataset = _make_dataset() + + with ( + patch("services.dataset_service.redis_client") as mock_redis, + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.delete_segment_from_index_task") as delete_task, + ): + mock_redis.get.return_value = None + mock_db.session.query.return_value.where.return_value.all.return_value = [("child-1",), ("child-2",)] + + SegmentService.delete_segment(segment, document, dataset) + + assert document.word_count == 6 + mock_redis.setex.assert_called_once_with(f"segment_{segment.id}_delete_indexing", 600, 1) + delete_task.delay.assert_called_once_with( + ["parent-node"], + dataset.id, + document.id, + [segment.id], + ["child-1", "child-2"], + ) + mock_db.session.delete.assert_called_once_with(segment) + mock_db.session.add.assert_called_once_with(document) + mock_db.session.commit.assert_called_once() + + def test_delete_segment_rejects_when_delete_is_already_in_progress(self): + segment = _make_segment() + document = _make_document() + dataset = _make_dataset() + + with 
patch("services.dataset_service.redis_client") as mock_redis: + mock_redis.get.return_value = "1" + + with pytest.raises(ValueError, match="Segment is deleting"): + SegmentService.delete_segment(segment, document, dataset) + + def test_delete_segments_removes_records_and_clamps_document_word_count(self): + dataset = _make_dataset() + document = _make_document(word_count=3) + current_user = SimpleNamespace(current_tenant_id="tenant-1") + + with ( + patch("services.dataset_service.current_user", current_user), + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.delete_segment_from_index_task") as delete_task, + ): + segments_query = MagicMock() + segments_query.with_entities.return_value.where.return_value.all.return_value = [ + ("node-1", "segment-1", 2), + ("node-2", "segment-2", 5), + ] + child_query = MagicMock() + child_query.where.return_value.all.return_value = [("child-1",)] + delete_query = MagicMock() + delete_query.where.return_value.delete.return_value = 2 + mock_db.session.query.side_effect = [segments_query, child_query, delete_query] + + SegmentService.delete_segments(["segment-1", "segment-2"], document, dataset) + + assert document.word_count == 0 + mock_db.session.add.assert_called_once_with(document) + delete_task.delay.assert_called_once_with( + ["node-1", "node-2"], + dataset.id, + document.id, + ["segment-1", "segment-2"], + ["child-1"], + ) + delete_query.where.return_value.delete.assert_called_once() + mock_db.session.commit.assert_called_once() + + def test_update_segments_status_enables_only_segments_without_indexing_cache(self): + dataset = _make_dataset() + document = _make_document() + segment_a = _make_segment(segment_id="segment-a", enabled=False) + segment_b = _make_segment(segment_id="segment-b", enabled=False) + current_user = SimpleNamespace(id="user-1", current_tenant_id="tenant-1") + + with ( + patch("services.dataset_service.current_user", current_user), + 
patch("services.dataset_service.redis_client") as mock_redis, + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.naive_utc_now", return_value="now"), + patch("services.dataset_service.enable_segments_to_index_task") as enable_task, + ): + mock_db.session.scalars.return_value.all.return_value = [segment_a, segment_b] + mock_redis.get.side_effect = [None, "1"] + + SegmentService.update_segments_status(["segment-a", "segment-b"], "enable", dataset, document) + + assert segment_a.enabled is True + assert segment_a.disabled_at is None + assert segment_a.disabled_by is None + assert segment_b.enabled is False + mock_db.session.add.assert_called_once_with(segment_a) + mock_db.session.commit.assert_called_once() + enable_task.delay.assert_called_once_with(["segment-a"], dataset.id, document.id) + + def test_update_segments_status_disables_only_segments_without_indexing_cache(self): + dataset = _make_dataset() + document = _make_document() + segment_a = _make_segment(segment_id="segment-a", enabled=True) + segment_b = _make_segment(segment_id="segment-b", enabled=True) + current_user = SimpleNamespace(id="user-1", current_tenant_id="tenant-1") + + with ( + patch("services.dataset_service.current_user", current_user), + patch("services.dataset_service.redis_client") as mock_redis, + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.naive_utc_now", return_value="now"), + patch("services.dataset_service.disable_segments_from_index_task") as disable_task, + ): + mock_db.session.scalars.return_value.all.return_value = [segment_a, segment_b] + mock_redis.get.side_effect = [None, "1"] + + SegmentService.update_segments_status(["segment-a", "segment-b"], "disable", dataset, document) + + assert segment_a.enabled is False + assert segment_a.disabled_at == "now" + assert segment_a.disabled_by == current_user.id + assert segment_b.enabled is True + mock_db.session.add.assert_called_once_with(segment_a) + 
mock_db.session.commit.assert_called_once() + disable_task.delay.assert_called_once_with(["segment-a"], dataset.id, document.id) + + +class TestSegmentServiceChildChunkTailHelpers: + """Unit tests for the remaining child-chunk helper branches.""" + + def test_update_child_chunk_rolls_back_on_vector_failure(self): + dataset = SimpleNamespace(id="dataset-1") + child_chunk = _make_child_chunk() + + with ( + patch("services.dataset_service.current_user", SimpleNamespace(id="user-1")), + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.naive_utc_now", return_value="now"), + patch("services.dataset_service.VectorService") as vector_service, + ): + vector_service.update_child_chunk_vector.side_effect = RuntimeError("vector failed") + + with pytest.raises(ChildChunkIndexingError, match="vector failed"): + SegmentService.update_child_chunk( + "new content", child_chunk, SimpleNamespace(), SimpleNamespace(), dataset + ) + + mock_db.session.rollback.assert_called_once() + mock_db.session.commit.assert_not_called() + + def test_delete_child_chunk_commits_after_successful_vector_delete(self): + dataset = SimpleNamespace(id="dataset-1") + child_chunk = _make_child_chunk() + + with ( + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.VectorService") as vector_service, + ): + SegmentService.delete_child_chunk(child_chunk, dataset) + + mock_db.session.delete.assert_called_once_with(child_chunk) + vector_service.delete_child_chunk_vector.assert_called_once_with(child_chunk, dataset) + mock_db.session.commit.assert_called_once() + + +class TestSegmentServiceAdditionalRegenerationBranches: + """Additional unit tests for segment update and regeneration edge cases.""" + + @pytest.fixture + def account_context(self): + account = create_autospec(Account, instance=True) + account.id = "user-1" + account.current_tenant_id = "tenant-1" + + with patch("services.dataset_service.current_user", account): + yield account + + def 
test_update_segment_same_content_updates_answer_and_document_word_count_for_qa_segments(self, account_context): + segment = _make_segment(content="question", word_count=8) + document = _make_document(doc_form=IndexStructureType.QA_INDEX, word_count=20) + dataset = _make_dataset() + refreshed_segment = SimpleNamespace(id=segment.id) + + with ( + patch("services.dataset_service.redis_client") as mock_redis, + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.VectorService") as vector_service, + ): + mock_redis.get.return_value = None + mock_db.session.query.return_value.where.return_value.first.return_value = refreshed_segment + + result = SegmentService.update_segment( + SegmentUpdateArgs(content="question", answer="new answer"), + segment, + document, + dataset, + ) + + assert result is refreshed_segment + assert segment.answer == "new answer" + assert segment.word_count == len("question") + len("new answer") + assert document.word_count == 20 + (len("question") + len("new answer") - 8) + vector_service.update_segment_vector.assert_not_called() + vector_service.update_multimodel_vector.assert_called_once_with(segment, [], dataset) + + def test_update_segment_content_change_uses_answer_when_counting_tokens_for_qa_segments(self, account_context): + segment = _make_segment(content="old", word_count=3) + document = _make_document(doc_form=IndexStructureType.QA_INDEX, word_count=10) + dataset = _make_dataset(indexing_technique="high_quality") + refreshed_segment = SimpleNamespace(id=segment.id) + embedding_model = MagicMock() + embedding_model.get_text_embedding_num_tokens.return_value = [21] + + with ( + patch("services.dataset_service.redis_client") as mock_redis, + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.ModelManager") as model_manager_cls, + patch("services.dataset_service.VectorService") as vector_service, + patch("services.dataset_service.helper.generate_text_hash", 
return_value="hash-qa"), + patch("services.dataset_service.naive_utc_now", return_value="now"), + ): + mock_redis.get.return_value = None + model_manager_cls.for_tenant.return_value.get_model_instance.return_value = embedding_model + summary_query = MagicMock() + summary_query.where.return_value.first.return_value = None + refreshed_query = MagicMock() + refreshed_query.where.return_value.first.return_value = refreshed_segment + mock_db.session.query.side_effect = [summary_query, refreshed_query] + + result = SegmentService.update_segment( + SegmentUpdateArgs(content="new question", answer="new answer", keywords=["kw-1"]), + segment, + document, + dataset, + ) + + assert result is refreshed_segment + embedding_model.get_text_embedding_num_tokens.assert_called_once_with(texts=["new questionnew answer"]) + assert segment.answer == "new answer" + assert segment.tokens == 21 + assert segment.word_count == len("new question") + len("new answer") + vector_service.update_segment_vector.assert_called_once_with(["kw-1"], segment, dataset) + vector_service.update_multimodel_vector.assert_called_once_with(segment, [], dataset) + + def test_update_segment_content_change_parent_child_uses_default_embedding_and_ignores_summary_failures( + self, account_context + ): + segment = _make_segment(content="old", word_count=3) + document = _make_document( + doc_form=IndexStructureType.PARENT_CHILD_INDEX, + word_count=10, + ) + dataset = _make_dataset(indexing_technique="high_quality") + dataset.embedding_model_provider = None + refreshed_segment = SimpleNamespace(id=segment.id) + processing_rule = SimpleNamespace(id=document.dataset_process_rule_id) + existing_summary = SimpleNamespace(summary_content="old summary") + embedding_model_instance = object() + + with ( + patch("services.dataset_service.redis_client") as mock_redis, + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.ModelManager") as model_manager_cls, + 
patch("services.dataset_service.VectorService") as vector_service, + patch("services.dataset_service.helper.generate_text_hash", return_value="hash-parent"), + patch("services.dataset_service.naive_utc_now", return_value="now"), + patch("services.summary_index_service.SummaryIndexService.update_summary_for_segment") as update_summary, + ): + mock_redis.get.return_value = None + model_manager_cls.for_tenant.return_value.get_default_model_instance.return_value = embedding_model_instance + update_summary.side_effect = RuntimeError("summary failed") + + processing_rule_query = MagicMock() + processing_rule_query.where.return_value.first.return_value = processing_rule + summary_query = MagicMock() + summary_query.where.return_value.first.return_value = existing_summary + refreshed_query = MagicMock() + refreshed_query.where.return_value.first.return_value = refreshed_segment + mock_db.session.query.side_effect = [processing_rule_query, summary_query, refreshed_query] + + result = SegmentService.update_segment( + SegmentUpdateArgs(content="new parent content", regenerate_child_chunks=True, summary="new summary"), + segment, + document, + dataset, + ) + + assert result is refreshed_segment + model_manager_cls.for_tenant.return_value.get_default_model_instance.assert_called_once_with( + tenant_id="tenant-1", + model_type=ModelType.TEXT_EMBEDDING, + ) + vector_service.generate_child_chunks.assert_called_once_with( + segment, + document, + dataset, + embedding_model_instance, + processing_rule, + True, + ) + update_summary.assert_called_once_with(segment, dataset, "new summary") + vector_service.update_multimodel_vector.assert_called_once_with(segment, [], dataset) + + def test_update_segment_same_content_parent_child_marks_segment_error_for_non_high_quality_dataset( + self, account_context + ): + segment = _make_segment(content="same content", word_count=len("same content")) + document = _make_document( + doc_form=IndexStructureType.PARENT_CHILD_INDEX, + word_count=20, + ) 
+ dataset = _make_dataset(indexing_technique="economy") + refreshed_segment = SimpleNamespace(id=segment.id) + + with ( + patch("services.dataset_service.redis_client") as mock_redis, + patch("services.dataset_service.db") as mock_db, + patch("services.dataset_service.naive_utc_now", return_value="now"), + patch("services.dataset_service.VectorService") as vector_service, + ): + mock_redis.get.return_value = None + mock_db.session.query.return_value.where.return_value.first.return_value = refreshed_segment + + result = SegmentService.update_segment( + SegmentUpdateArgs(content="same content", regenerate_child_chunks=True), + segment, + document, + dataset, + ) + + assert result is refreshed_segment + assert segment.enabled is False + assert segment.disabled_at == "now" + assert segment.status == "error" + assert segment.error == "The knowledge base index technique is not high quality!" + vector_service.update_multimodel_vector.assert_not_called() diff --git a/api/tests/unit_tests/services/test_datasource_provider_service.py b/api/tests/unit_tests/services/test_datasource_provider_service.py index 105ef7ba48..3df7d500cf 100644 --- a/api/tests/unit_tests/services/test_datasource_provider_service.py +++ b/api/tests/unit_tests/services/test_datasource_provider_service.py @@ -1,10 +1,10 @@ from unittest.mock import MagicMock, patch import pytest +from graphon.model_runtime.entities.provider_entities import FormType from sqlalchemy.orm import Session from core.plugin.entities.plugin_daemon import CredentialType -from dify_graph.model_runtime.entities.provider_entities import FormType from models.account import Account from models.model import EndUser from models.oauth import DatasourceProvider diff --git a/api/tests/unit_tests/services/test_human_input_service.py b/api/tests/unit_tests/services/test_human_input_service.py index 375e47d7fc..9be475d043 100644 --- a/api/tests/unit_tests/services/test_human_input_service.py +++ 
b/api/tests/unit_tests/services/test_human_input_service.py @@ -3,18 +3,19 @@ from datetime import datetime, timedelta from unittest.mock import MagicMock import pytest +from graphon.nodes.human_input.entities import ( + FormDefinition, + FormInput, + UserAction, +) +from graphon.nodes.human_input.enums import FormInputType, HumanInputFormKind, HumanInputFormStatus import services.human_input_service as human_input_service_module from core.repositories.human_input_repository import ( HumanInputFormRecord, HumanInputFormSubmissionRepository, ) -from dify_graph.nodes.human_input.entities import ( - FormDefinition, - FormInput, - UserAction, -) -from dify_graph.nodes.human_input.enums import FormInputType, HumanInputFormKind, HumanInputFormStatus +from libs.datetime_utils import naive_utc_now from models.human_input import RecipientType from services.human_input_service import ( Form, @@ -51,11 +52,11 @@ def sample_form_record(): inputs=[], user_actions=[UserAction(id="submit", title="Submit")], rendered_content="

hello

", - expiration_time=datetime.utcnow() + timedelta(hours=1), + expiration_time=naive_utc_now() + timedelta(hours=1), ), rendered_content="

hello

", - created_at=datetime.utcnow(), - expiration_time=datetime.utcnow() + timedelta(hours=1), + created_at=naive_utc_now(), + expiration_time=naive_utc_now() + timedelta(hours=1), status=HumanInputFormStatus.WAITING, selected_action_id=None, submitted_data=None, @@ -101,8 +102,8 @@ def test_ensure_form_active_respects_global_timeout(monkeypatch, sample_form_rec service = HumanInputService(session_factory) expired_record = dataclasses.replace( sample_form_record, - created_at=datetime.utcnow() - timedelta(hours=2), - expiration_time=datetime.utcnow() + timedelta(hours=2), + created_at=naive_utc_now() - timedelta(hours=2), + expiration_time=naive_utc_now() + timedelta(hours=2), ) monkeypatch.setattr(human_input_service_module.dify_config, "HUMAN_INPUT_GLOBAL_TIMEOUT_SECONDS", 3600) @@ -391,7 +392,7 @@ def test_ensure_form_active_errors(sample_form_record, mock_session_factory): service = HumanInputService(session_factory) # Submitted - submitted_record = dataclasses.replace(sample_form_record, submitted_at=datetime.utcnow()) + submitted_record = dataclasses.replace(sample_form_record, submitted_at=naive_utc_now()) with pytest.raises(human_input_service_module.FormSubmittedError): service.ensure_form_active(Form(submitted_record)) @@ -402,7 +403,7 @@ def test_ensure_form_active_errors(sample_form_record, mock_session_factory): # Expired time expired_time_record = dataclasses.replace( - sample_form_record, expiration_time=datetime.utcnow() - timedelta(minutes=1) + sample_form_record, expiration_time=naive_utc_now() - timedelta(minutes=1) ) with pytest.raises(FormExpiredError): service.ensure_form_active(Form(expired_time_record)) @@ -411,7 +412,7 @@ def test_ensure_form_active_errors(sample_form_record, mock_session_factory): def test_ensure_not_submitted_raises(sample_form_record, mock_session_factory): session_factory, _ = mock_session_factory service = HumanInputService(session_factory) - submitted_record = dataclasses.replace(sample_form_record, 
submitted_at=datetime.utcnow()) + submitted_record = dataclasses.replace(sample_form_record, submitted_at=naive_utc_now()) with pytest.raises(human_input_service_module.FormSubmittedError): service._ensure_not_submitted(Form(submitted_record)) diff --git a/api/tests/unit_tests/services/test_knowledge_service.py b/api/tests/unit_tests/services/test_knowledge_service.py index bc0caee071..53c243ad71 100644 --- a/api/tests/unit_tests/services/test_knowledge_service.py +++ b/api/tests/unit_tests/services/test_knowledge_service.py @@ -3,7 +3,7 @@ from unittest.mock import MagicMock, patch import pytest -from services.knowledge_service import ExternalDatasetTestService +from services.knowledge_service import BedrockRetrievalSetting, ExternalDatasetTestService class TestKnowledgeService: @@ -24,7 +24,7 @@ class TestKnowledgeService: mock_client = MagicMock() mock_boto_client.return_value = mock_client - retrieval_setting = {"top_k": 4, "score_threshold": 0.5} + retrieval_setting = BedrockRetrievalSetting(top_k=4, score_threshold=0.5) query = "test query" knowledge_id = "kb-123" @@ -87,7 +87,10 @@ class TestKnowledgeService: mock_client.retrieve.return_value = {"ResponseMetadata": {"HTTPStatusCode": 200}, "retrievalResults": []} # Act - result = cast(dict[str, Any], ExternalDatasetTestService.knowledge_retrieval({"top_k": 1}, "query", "kb")) + result = cast( + dict[str, Any], + ExternalDatasetTestService.knowledge_retrieval(BedrockRetrievalSetting(top_k=1), "query", "kb"), + ) # Assert assert result["records"] == [] @@ -104,7 +107,10 @@ class TestKnowledgeService: mock_client.retrieve.return_value = {"ResponseMetadata": {"HTTPStatusCode": 500}} # Act - result = cast(dict[str, Any], ExternalDatasetTestService.knowledge_retrieval({"top_k": 1}, "query", "kb")) + result = cast( + dict[str, Any], + ExternalDatasetTestService.knowledge_retrieval(BedrockRetrievalSetting(top_k=1), "query", "kb"), + ) # Assert assert result["records"] == [] @@ -114,7 +120,7 @@ class 
TestKnowledgeService: with patch("services.knowledge_service.boto3.client") as mock_boto: mock_boto.side_effect = Exception("client init failed") with pytest.raises(Exception) as exc_info: - ExternalDatasetTestService.knowledge_retrieval({"top_k": 1}, "query", "kb") + ExternalDatasetTestService.knowledge_retrieval(BedrockRetrievalSetting(top_k=1), "query", "kb") assert "client init failed" in str(exc_info.value) # ===== Edge Cases ===== @@ -139,7 +145,10 @@ class TestKnowledgeService: # Act # retrieval_setting missing "score_threshold" - result = cast(dict[str, Any], ExternalDatasetTestService.knowledge_retrieval({"top_k": 1}, "query", "kb")) + result = cast( + dict[str, Any], + ExternalDatasetTestService.knowledge_retrieval(BedrockRetrievalSetting(top_k=1), "query", "kb"), + ) # Assert assert len(result["records"]) == 1 diff --git a/api/tests/unit_tests/services/test_message_service.py b/api/tests/unit_tests/services/test_message_service.py index e7740ef93a..101b9bff24 100644 --- a/api/tests/unit_tests/services/test_message_service.py +++ b/api/tests/unit_tests/services/test_message_service.py @@ -933,7 +933,7 @@ class TestMessageServiceSuggestedQuestions: ) # Test 28: get_suggested_questions_after_answer - Advanced Chat success - @patch("services.message_service.ModelManager") + @patch("services.message_service.ModelManager.for_tenant") @patch("services.message_service.WorkflowService") @patch("services.message_service.AdvancedChatAppConfigManager") @patch("services.message_service.TokenBufferMemory") @@ -983,7 +983,7 @@ class TestMessageServiceSuggestedQuestions: # Test 29: get_suggested_questions_after_answer - Chat app success (no override) @patch("services.message_service.db") - @patch("services.message_service.ModelManager") + @patch("services.message_service.ModelManager.for_tenant") @patch("services.message_service.TokenBufferMemory") @patch("services.message_service.LLMGenerator") @patch("services.message_service.TraceQueueManager") diff --git 
a/api/tests/unit_tests/services/test_metadata_partial_update.py b/api/tests/unit_tests/services/test_metadata_partial_update.py deleted file mode 100644 index 60252784bc..0000000000 --- a/api/tests/unit_tests/services/test_metadata_partial_update.py +++ /dev/null @@ -1,187 +0,0 @@ -import unittest -from unittest.mock import MagicMock, patch - -import pytest - -from models.dataset import Dataset, Document -from services.entities.knowledge_entities.knowledge_entities import ( - DocumentMetadataOperation, - MetadataDetail, - MetadataOperationData, -) -from services.metadata_service import MetadataService - - -class TestMetadataPartialUpdate(unittest.TestCase): - def setUp(self): - self.dataset = MagicMock(spec=Dataset) - self.dataset.id = "dataset_id" - self.dataset.built_in_field_enabled = False - - self.document = MagicMock(spec=Document) - self.document.id = "doc_id" - self.document.doc_metadata = {"existing_key": "existing_value"} - self.document.data_source_type = "upload_file" - - @patch("services.metadata_service.db") - @patch("services.metadata_service.DocumentService") - @patch("services.metadata_service.current_account_with_tenant") - @patch("services.metadata_service.redis_client") - def test_partial_update_merges_metadata(self, mock_redis, mock_current_account, mock_document_service, mock_db): - # Setup mocks - mock_redis.get.return_value = None - mock_document_service.get_document.return_value = self.document - mock_current_account.return_value = (MagicMock(id="user_id"), "tenant_id") - - # Mock DB query for existing bindings - - # No existing binding for new key - mock_db.session.query.return_value.filter_by.return_value.first.return_value = None - - # Input data - operation = DocumentMetadataOperation( - document_id="doc_id", - metadata_list=[MetadataDetail(id="new_meta_id", name="new_key", value="new_value")], - partial_update=True, - ) - metadata_args = MetadataOperationData(operation_data=[operation]) - - # Execute - 
MetadataService.update_documents_metadata(self.dataset, metadata_args) - - # Verify - # 1. Check that doc_metadata contains BOTH existing and new keys - expected_metadata = {"existing_key": "existing_value", "new_key": "new_value"} - assert self.document.doc_metadata == expected_metadata - - # 2. Check that existing bindings were NOT deleted - # The delete call in the original code: db.session.query(...).filter_by(...).delete() - # In partial update, this should NOT be called. - mock_db.session.query.return_value.filter_by.return_value.delete.assert_not_called() - - @patch("services.metadata_service.db") - @patch("services.metadata_service.DocumentService") - @patch("services.metadata_service.current_account_with_tenant") - @patch("services.metadata_service.redis_client") - def test_full_update_replaces_metadata(self, mock_redis, mock_current_account, mock_document_service, mock_db): - # Setup mocks - mock_redis.get.return_value = None - mock_document_service.get_document.return_value = self.document - mock_current_account.return_value = (MagicMock(id="user_id"), "tenant_id") - - # Input data (partial_update=False by default) - operation = DocumentMetadataOperation( - document_id="doc_id", - metadata_list=[MetadataDetail(id="new_meta_id", name="new_key", value="new_value")], - partial_update=False, - ) - metadata_args = MetadataOperationData(operation_data=[operation]) - - # Execute - MetadataService.update_documents_metadata(self.dataset, metadata_args) - - # Verify - # 1. Check that doc_metadata contains ONLY the new key - expected_metadata = {"new_key": "new_value"} - assert self.document.doc_metadata == expected_metadata - - # 2. Check that existing bindings WERE deleted - # In full update (default), we expect the existing bindings to be cleared. 
- mock_db.session.query.return_value.filter_by.return_value.delete.assert_called() - - @patch("services.metadata_service.db") - @patch("services.metadata_service.DocumentService") - @patch("services.metadata_service.current_account_with_tenant") - @patch("services.metadata_service.redis_client") - def test_partial_update_skips_existing_binding( - self, mock_redis, mock_current_account, mock_document_service, mock_db - ): - # Setup mocks - mock_redis.get.return_value = None - mock_document_service.get_document.return_value = self.document - mock_current_account.return_value = (MagicMock(id="user_id"), "tenant_id") - - # Mock DB query to return an existing binding - # This simulates that the document ALREADY has the metadata we are trying to add - mock_existing_binding = MagicMock() - mock_db.session.query.return_value.filter_by.return_value.first.return_value = mock_existing_binding - - # Input data - operation = DocumentMetadataOperation( - document_id="doc_id", - metadata_list=[MetadataDetail(id="existing_meta_id", name="existing_key", value="existing_value")], - partial_update=True, - ) - metadata_args = MetadataOperationData(operation_data=[operation]) - - # Execute - MetadataService.update_documents_metadata(self.dataset, metadata_args) - - # Verify - # We verify that db.session.add was NOT called for DatasetMetadataBinding - # Since we can't easily check "not called with specific type" on the generic add method without complex logic, - # we can check if the number of add calls is 1 (only for the document update) instead of 2 (document + binding) - - # Expected calls: - # 1. db.session.add(document) - # 2. NO db.session.add(binding) because it exists - - # Note: In the code, db.session.add is called for document. - # Then loop over metadata_list. - # If existing_binding found, continue. - # So binding add should be skipped. 
- - # Let's filter the calls to add to see what was added - add_calls = mock_db.session.add.call_args_list - added_objects = [call.args[0] for call in add_calls] - - # Check that no DatasetMetadataBinding was added - from models.dataset import DatasetMetadataBinding - - has_binding_add = any( - isinstance(obj, DatasetMetadataBinding) - or (isinstance(obj, MagicMock) and getattr(obj, "__class__", None) == DatasetMetadataBinding) - for obj in added_objects - ) - - # Since we mock everything, checking isinstance might be tricky if DatasetMetadataBinding - # is not the exact class used in the service (imports match). - # But we can check the count. - # If it were added, there would be 2 calls. If skipped, 1 call. - assert mock_db.session.add.call_count == 1 - - @patch("services.metadata_service.db") - @patch("services.metadata_service.DocumentService") - @patch("services.metadata_service.current_account_with_tenant") - @patch("services.metadata_service.redis_client") - def test_rollback_called_on_commit_failure(self, mock_redis, mock_current_account, mock_document_service, mock_db): - """When db.session.commit() raises, rollback must be called and the exception must propagate.""" - # Setup mocks - mock_redis.get.return_value = None - mock_document_service.get_document.return_value = self.document - mock_current_account.return_value = (MagicMock(id="user_id"), "tenant_id") - mock_db.session.query.return_value.filter_by.return_value.first.return_value = None - - # Make commit raise an exception - mock_db.session.commit.side_effect = RuntimeError("database connection lost") - - operation = DocumentMetadataOperation( - document_id="doc_id", - metadata_list=[MetadataDetail(id="meta_id", name="key", value="value")], - partial_update=True, - ) - metadata_args = MetadataOperationData(operation_data=[operation]) - - # Act & Assert: the exception must propagate - with pytest.raises(RuntimeError, match="database connection lost"): - 
MetadataService.update_documents_metadata(self.dataset, metadata_args) - - # Verify rollback was called - mock_db.session.rollback.assert_called_once() - - # Verify the lock key was cleaned up despite the failure - mock_redis.delete.assert_called_with("document_metadata_lock_doc_id") - - -if __name__ == "__main__": - unittest.main() diff --git a/api/tests/unit_tests/services/test_metadata_service.py b/api/tests/unit_tests/services/test_metadata_service.py deleted file mode 100644 index bbdc16d4f8..0000000000 --- a/api/tests/unit_tests/services/test_metadata_service.py +++ /dev/null @@ -1,558 +0,0 @@ -from __future__ import annotations - -from dataclasses import dataclass -from datetime import UTC, datetime -from types import SimpleNamespace -from typing import Any, cast -from unittest.mock import MagicMock - -import pytest -from pytest_mock import MockerFixture - -from core.rag.index_processor.constant.built_in_field import BuiltInField, MetadataDataSource -from models.dataset import Dataset -from services.entities.knowledge_entities.knowledge_entities import ( - DocumentMetadataOperation, - MetadataArgs, - MetadataDetail, - MetadataOperationData, -) -from services.metadata_service import MetadataService - - -@dataclass -class _DocumentStub: - id: str - name: str - uploader: str - upload_date: datetime - last_update_date: datetime - data_source_type: str - doc_metadata: dict[str, object] | None - - -@pytest.fixture -def mock_db(mocker: MockerFixture) -> MagicMock: - mocked_db = mocker.patch("services.metadata_service.db") - mocked_db.session = MagicMock() - return mocked_db - - -@pytest.fixture -def mock_redis_client(mocker: MockerFixture) -> MagicMock: - return mocker.patch("services.metadata_service.redis_client") - - -@pytest.fixture -def mock_current_account(mocker: MockerFixture) -> MagicMock: - mock_user = SimpleNamespace(id="user-1") - return mocker.patch("services.metadata_service.current_account_with_tenant", return_value=(mock_user, "tenant-1")) - - -def 
_build_document(document_id: str, doc_metadata: dict[str, object] | None = None) -> _DocumentStub: - now = datetime(2025, 1, 1, 10, 30, tzinfo=UTC) - return _DocumentStub( - id=document_id, - name=f"doc-{document_id}", - uploader="qa@example.com", - upload_date=now, - last_update_date=now, - data_source_type="upload_file", - doc_metadata=doc_metadata, - ) - - -def _dataset(**kwargs: Any) -> Dataset: - return cast(Dataset, SimpleNamespace(**kwargs)) - - -def test_create_metadata_should_raise_value_error_when_name_exceeds_limit() -> None: - # Arrange - metadata_args = MetadataArgs(type="string", name="x" * 256) - - # Act + Assert - with pytest.raises(ValueError, match="cannot exceed 255"): - MetadataService.create_metadata("dataset-1", metadata_args) - - -def test_create_metadata_should_raise_value_error_when_metadata_name_already_exists( - mock_db: MagicMock, - mock_current_account: MagicMock, -) -> None: - # Arrange - metadata_args = MetadataArgs(type="string", name="priority") - mock_db.session.query.return_value.filter_by.return_value.first.return_value = object() - - # Act + Assert - with pytest.raises(ValueError, match="already exists"): - MetadataService.create_metadata("dataset-1", metadata_args) - - # Assert - mock_current_account.assert_called_once() - - -def test_create_metadata_should_raise_value_error_when_name_collides_with_builtin( - mock_db: MagicMock, mock_current_account: MagicMock -) -> None: - # Arrange - metadata_args = MetadataArgs(type="string", name=BuiltInField.document_name) - mock_db.session.query.return_value.filter_by.return_value.first.return_value = None - - # Act + Assert - with pytest.raises(ValueError, match="Built-in fields"): - MetadataService.create_metadata("dataset-1", metadata_args) - - -def test_create_metadata_should_persist_metadata_when_input_is_valid( - mock_db: MagicMock, mock_current_account: MagicMock -) -> None: - # Arrange - metadata_args = MetadataArgs(type="number", name="score") - 
mock_db.session.query.return_value.filter_by.return_value.first.return_value = None - - # Act - result = MetadataService.create_metadata("dataset-1", metadata_args) - - # Assert - assert result.tenant_id == "tenant-1" - assert result.dataset_id == "dataset-1" - assert result.type == "number" - assert result.name == "score" - assert result.created_by == "user-1" - mock_db.session.add.assert_called_once_with(result) - mock_db.session.commit.assert_called_once() - mock_current_account.assert_called_once() - - -def test_update_metadata_name_should_raise_value_error_when_name_exceeds_limit() -> None: - # Arrange - too_long_name = "x" * 256 - - # Act + Assert - with pytest.raises(ValueError, match="cannot exceed 255"): - MetadataService.update_metadata_name("dataset-1", "metadata-1", too_long_name) - - -def test_update_metadata_name_should_raise_value_error_when_duplicate_name_exists( - mock_db: MagicMock, mock_current_account: MagicMock -) -> None: - # Arrange - mock_db.session.query.return_value.filter_by.return_value.first.return_value = object() - - # Act + Assert - with pytest.raises(ValueError, match="already exists"): - MetadataService.update_metadata_name("dataset-1", "metadata-1", "duplicate") - - # Assert - mock_current_account.assert_called_once() - - -def test_update_metadata_name_should_raise_value_error_when_name_collides_with_builtin( - mock_db: MagicMock, - mock_current_account: MagicMock, -) -> None: - # Arrange - mock_db.session.query.return_value.filter_by.return_value.first.return_value = None - - # Act + Assert - with pytest.raises(ValueError, match="Built-in fields"): - MetadataService.update_metadata_name("dataset-1", "metadata-1", BuiltInField.source) - - # Assert - mock_current_account.assert_called_once() - - -def test_update_metadata_name_should_update_bound_documents_and_return_metadata( - mock_db: MagicMock, - mock_redis_client: MagicMock, - mock_current_account: MagicMock, - mocker: MockerFixture, -) -> None: - # Arrange - 
mock_redis_client.get.return_value = None - fixed_now = datetime(2025, 2, 1, 0, 0, tzinfo=UTC) - mocker.patch("services.metadata_service.naive_utc_now", return_value=fixed_now) - - metadata = SimpleNamespace(id="metadata-1", name="old_name", updated_by=None, updated_at=None) - bindings = [SimpleNamespace(document_id="doc-1"), SimpleNamespace(document_id="doc-2")] - query_duplicate = MagicMock() - query_duplicate.filter_by.return_value.first.return_value = None - query_metadata = MagicMock() - query_metadata.filter_by.return_value.first.return_value = metadata - query_bindings = MagicMock() - query_bindings.filter_by.return_value.all.return_value = bindings - mock_db.session.query.side_effect = [query_duplicate, query_metadata, query_bindings] - - doc_1 = _build_document("1", {"old_name": "value", "other": "keep"}) - doc_2 = _build_document("2", None) - mock_get_documents = mocker.patch("services.metadata_service.DocumentService.get_document_by_ids") - mock_get_documents.return_value = [doc_1, doc_2] - - # Act - result = MetadataService.update_metadata_name("dataset-1", "metadata-1", "new_name") - - # Assert - assert result is metadata - assert metadata.name == "new_name" - assert metadata.updated_by == "user-1" - assert metadata.updated_at == fixed_now - assert doc_1.doc_metadata == {"other": "keep", "new_name": "value"} - assert doc_2.doc_metadata == {"new_name": None} - mock_get_documents.assert_called_once_with(["doc-1", "doc-2"]) - mock_db.session.commit.assert_called_once() - mock_redis_client.delete.assert_called_once_with("dataset_metadata_lock_dataset-1") - mock_current_account.assert_called_once() - - -def test_update_metadata_name_should_return_none_when_metadata_does_not_exist( - mock_db: MagicMock, - mock_redis_client: MagicMock, - mock_current_account: MagicMock, - mocker: MockerFixture, -) -> None: - # Arrange - mock_redis_client.get.return_value = None - mock_logger = mocker.patch("services.metadata_service.logger") - - query_duplicate = MagicMock() 
- query_duplicate.filter_by.return_value.first.return_value = None - query_metadata = MagicMock() - query_metadata.filter_by.return_value.first.return_value = None - mock_db.session.query.side_effect = [query_duplicate, query_metadata] - - # Act - result = MetadataService.update_metadata_name("dataset-1", "missing-id", "new_name") - - # Assert - assert result is None - mock_logger.exception.assert_called_once() - mock_redis_client.delete.assert_called_once_with("dataset_metadata_lock_dataset-1") - mock_current_account.assert_called_once() - - -def test_delete_metadata_should_remove_metadata_and_related_document_fields( - mock_db: MagicMock, - mock_redis_client: MagicMock, - mocker: MockerFixture, -) -> None: - # Arrange - mock_redis_client.get.return_value = None - metadata = SimpleNamespace(id="metadata-1", name="obsolete") - bindings = [SimpleNamespace(document_id="doc-1")] - query_metadata = MagicMock() - query_metadata.filter_by.return_value.first.return_value = metadata - query_bindings = MagicMock() - query_bindings.filter_by.return_value.all.return_value = bindings - mock_db.session.query.side_effect = [query_metadata, query_bindings] - - document = _build_document("1", {"obsolete": "legacy", "remaining": "value"}) - mocker.patch("services.metadata_service.DocumentService.get_document_by_ids", return_value=[document]) - - # Act - result = MetadataService.delete_metadata("dataset-1", "metadata-1") - - # Assert - assert result is metadata - assert document.doc_metadata == {"remaining": "value"} - mock_db.session.delete.assert_called_once_with(metadata) - mock_db.session.commit.assert_called_once() - mock_redis_client.delete.assert_called_once_with("dataset_metadata_lock_dataset-1") - - -def test_delete_metadata_should_return_none_when_metadata_is_missing( - mock_db: MagicMock, - mock_redis_client: MagicMock, - mocker: MockerFixture, -) -> None: - # Arrange - mock_redis_client.get.return_value = None - 
mock_db.session.query.return_value.filter_by.return_value.first.return_value = None - mock_logger = mocker.patch("services.metadata_service.logger") - - # Act - result = MetadataService.delete_metadata("dataset-1", "missing-id") - - # Assert - assert result is None - mock_logger.exception.assert_called_once() - mock_redis_client.delete.assert_called_once_with("dataset_metadata_lock_dataset-1") - - -def test_get_built_in_fields_should_return_all_expected_fields() -> None: - # Arrange - expected_names = { - BuiltInField.document_name, - BuiltInField.uploader, - BuiltInField.upload_date, - BuiltInField.last_update_date, - BuiltInField.source, - } - - # Act - result = MetadataService.get_built_in_fields() - - # Assert - assert {item["name"] for item in result} == expected_names - assert [item["type"] for item in result] == ["string", "string", "time", "time", "string"] - - -def test_enable_built_in_field_should_return_immediately_when_already_enabled( - mock_db: MagicMock, - mocker: MockerFixture, -) -> None: - # Arrange - dataset = _dataset(id="dataset-1", built_in_field_enabled=True) - get_docs = mocker.patch("services.metadata_service.DocumentService.get_working_documents_by_dataset_id") - - # Act - MetadataService.enable_built_in_field(dataset) - - # Assert - get_docs.assert_not_called() - mock_db.session.commit.assert_not_called() - - -def test_enable_built_in_field_should_populate_documents_and_enable_flag( - mock_db: MagicMock, - mock_redis_client: MagicMock, - mocker: MockerFixture, -) -> None: - # Arrange - mock_redis_client.get.return_value = None - dataset = _dataset(id="dataset-1", built_in_field_enabled=False) - doc_1 = _build_document("1", {"custom": "value"}) - doc_2 = _build_document("2", None) - mocker.patch( - "services.metadata_service.DocumentService.get_working_documents_by_dataset_id", - return_value=[doc_1, doc_2], - ) - - # Act - MetadataService.enable_built_in_field(dataset) - - # Assert - assert dataset.built_in_field_enabled is True - assert 
doc_1.doc_metadata is not None - assert doc_1.doc_metadata[BuiltInField.document_name] == "doc-1" - assert doc_1.doc_metadata[BuiltInField.source] == MetadataDataSource.upload_file - assert doc_2.doc_metadata is not None - assert doc_2.doc_metadata[BuiltInField.uploader] == "qa@example.com" - mock_db.session.commit.assert_called_once() - mock_redis_client.delete.assert_called_once_with("dataset_metadata_lock_dataset-1") - - -def test_disable_built_in_field_should_return_immediately_when_already_disabled( - mock_db: MagicMock, - mocker: MockerFixture, -) -> None: - # Arrange - dataset = _dataset(id="dataset-1", built_in_field_enabled=False) - get_docs = mocker.patch("services.metadata_service.DocumentService.get_working_documents_by_dataset_id") - - # Act - MetadataService.disable_built_in_field(dataset) - - # Assert - get_docs.assert_not_called() - mock_db.session.commit.assert_not_called() - - -def test_disable_built_in_field_should_remove_builtin_keys_and_disable_flag( - mock_db: MagicMock, - mock_redis_client: MagicMock, - mocker: MockerFixture, -) -> None: - # Arrange - mock_redis_client.get.return_value = None - dataset = _dataset(id="dataset-1", built_in_field_enabled=True) - document = _build_document( - "1", - { - BuiltInField.document_name: "doc", - BuiltInField.uploader: "user", - BuiltInField.upload_date: 1.0, - BuiltInField.last_update_date: 2.0, - BuiltInField.source: MetadataDataSource.upload_file, - "custom": "keep", - }, - ) - mocker.patch( - "services.metadata_service.DocumentService.get_working_documents_by_dataset_id", - return_value=[document], - ) - - # Act - MetadataService.disable_built_in_field(dataset) - - # Assert - assert dataset.built_in_field_enabled is False - assert document.doc_metadata == {"custom": "keep"} - mock_db.session.commit.assert_called_once() - mock_redis_client.delete.assert_called_once_with("dataset_metadata_lock_dataset-1") - - -def 
test_update_documents_metadata_should_replace_metadata_and_create_bindings_on_full_update( - mock_db: MagicMock, - mock_redis_client: MagicMock, - mock_current_account: MagicMock, - mocker: MockerFixture, -) -> None: - # Arrange - mock_redis_client.get.return_value = None - dataset = _dataset(id="dataset-1", built_in_field_enabled=False) - document = _build_document("1", {"legacy": "value"}) - mocker.patch("services.metadata_service.DocumentService.get_document", return_value=document) - delete_chain = mock_db.session.query.return_value.filter_by.return_value - delete_chain.delete.return_value = 1 - operation = DocumentMetadataOperation( - document_id="1", - metadata_list=[MetadataDetail(id="meta-1", name="priority", value="high")], - partial_update=False, - ) - metadata_args = MetadataOperationData(operation_data=[operation]) - - # Act - MetadataService.update_documents_metadata(dataset, metadata_args) - - # Assert - assert document.doc_metadata == {"priority": "high"} - delete_chain.delete.assert_called_once() - assert mock_db.session.commit.call_count == 1 - mock_redis_client.delete.assert_called_once_with("document_metadata_lock_1") - mock_current_account.assert_called_once() - - -def test_update_documents_metadata_should_skip_existing_binding_and_preserve_existing_fields_on_partial_update( - mock_db: MagicMock, - mock_redis_client: MagicMock, - mock_current_account: MagicMock, - mocker: MockerFixture, -) -> None: - # Arrange - mock_redis_client.get.return_value = None - dataset = _dataset(id="dataset-1", built_in_field_enabled=True) - document = _build_document("1", {"existing": "value"}) - mocker.patch("services.metadata_service.DocumentService.get_document", return_value=document) - mock_db.session.query.return_value.filter_by.return_value.first.return_value = object() - operation = DocumentMetadataOperation( - document_id="1", - metadata_list=[MetadataDetail(id="meta-1", name="new_key", value="new_value")], - partial_update=True, - ) - metadata_args = 
MetadataOperationData(operation_data=[operation]) - - # Act - MetadataService.update_documents_metadata(dataset, metadata_args) - - # Assert - assert document.doc_metadata is not None - assert document.doc_metadata["existing"] == "value" - assert document.doc_metadata["new_key"] == "new_value" - assert document.doc_metadata[BuiltInField.source] == MetadataDataSource.upload_file - assert mock_db.session.commit.call_count == 1 - assert mock_db.session.add.call_count == 1 - mock_redis_client.delete.assert_called_once_with("document_metadata_lock_1") - mock_current_account.assert_called_once() - - -def test_update_documents_metadata_should_raise_and_rollback_when_document_not_found( - mock_db: MagicMock, - mock_redis_client: MagicMock, - mocker: MockerFixture, -) -> None: - # Arrange - mock_redis_client.get.return_value = None - dataset = _dataset(id="dataset-1", built_in_field_enabled=False) - mocker.patch("services.metadata_service.DocumentService.get_document", return_value=None) - operation = DocumentMetadataOperation(document_id="404", metadata_list=[], partial_update=True) - metadata_args = MetadataOperationData(operation_data=[operation]) - - # Act + Assert - with pytest.raises(ValueError, match="Document not found"): - MetadataService.update_documents_metadata(dataset, metadata_args) - - # Assert - mock_db.session.rollback.assert_called_once() - mock_redis_client.delete.assert_called_once_with("document_metadata_lock_404") - - -@pytest.mark.parametrize( - ("dataset_id", "document_id", "expected_key"), - [ - ("dataset-1", None, "dataset_metadata_lock_dataset-1"), - (None, "doc-1", "document_metadata_lock_doc-1"), - ], -) -def test_knowledge_base_metadata_lock_check_should_set_lock_when_not_already_locked( - dataset_id: str | None, - document_id: str | None, - expected_key: str, - mock_redis_client: MagicMock, -) -> None: - # Arrange - mock_redis_client.get.return_value = None - - # Act - MetadataService.knowledge_base_metadata_lock_check(dataset_id, document_id) 
- - # Assert - mock_redis_client.set.assert_called_once_with(expected_key, 1, ex=3600) - - -def test_knowledge_base_metadata_lock_check_should_raise_when_dataset_lock_exists( - mock_redis_client: MagicMock, -) -> None: - # Arrange - mock_redis_client.get.return_value = 1 - - # Act + Assert - with pytest.raises(ValueError, match="knowledge base metadata operation is running"): - MetadataService.knowledge_base_metadata_lock_check("dataset-1", None) - - -def test_knowledge_base_metadata_lock_check_should_raise_when_document_lock_exists( - mock_redis_client: MagicMock, -) -> None: - # Arrange - mock_redis_client.get.return_value = 1 - - # Act + Assert - with pytest.raises(ValueError, match="document metadata operation is running"): - MetadataService.knowledge_base_metadata_lock_check(None, "doc-1") - - -def test_get_dataset_metadatas_should_exclude_builtin_and_include_binding_counts(mock_db: MagicMock) -> None: - # Arrange - dataset = _dataset( - id="dataset-1", - built_in_field_enabled=True, - doc_metadata=[ - {"id": "meta-1", "name": "priority", "type": "string"}, - {"id": "built-in", "name": "ignored", "type": "string"}, - {"id": "meta-2", "name": "score", "type": "number"}, - ], - ) - count_chain = mock_db.session.query.return_value.filter_by.return_value - count_chain.count.side_effect = [3, 1] - - # Act - result = MetadataService.get_dataset_metadatas(dataset) - - # Assert - assert result["built_in_field_enabled"] is True - assert result["doc_metadata"] == [ - {"id": "meta-1", "name": "priority", "type": "string", "count": 3}, - {"id": "meta-2", "name": "score", "type": "number", "count": 1}, - ] - - -def test_get_dataset_metadatas_should_return_empty_list_when_no_metadata(mock_db: MagicMock) -> None: - # Arrange - dataset = _dataset(id="dataset-1", built_in_field_enabled=False, doc_metadata=None) - - # Act - result = MetadataService.get_dataset_metadatas(dataset) - - # Assert - assert result == {"doc_metadata": [], "built_in_field_enabled": False} - 
mock_db.session.query.assert_not_called() diff --git a/api/tests/unit_tests/services/test_model_load_balancing_service.py b/api/tests/unit_tests/services/test_model_load_balancing_service.py index 49e572584b..b43e79dff5 100644 --- a/api/tests/unit_tests/services/test_model_load_balancing_service.py +++ b/api/tests/unit_tests/services/test_model_load_balancing_service.py @@ -6,18 +6,18 @@ from typing import Any, cast from unittest.mock import MagicMock import pytest -from pytest_mock import MockerFixture - -from constants import HIDDEN_VALUE -from dify_graph.model_runtime.entities.common_entities import I18nObject -from dify_graph.model_runtime.entities.model_entities import ModelType -from dify_graph.model_runtime.entities.provider_entities import ( +from graphon.model_runtime.entities.common_entities import I18nObject +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.entities.provider_entities import ( CredentialFormSchema, FieldModelSchema, FormType, ModelCredentialSchema, ProviderCredentialSchema, ) +from pytest_mock import MockerFixture + +from constants import HIDDEN_VALUE from models.provider import LoadBalancingModelConfig from services.model_load_balancing_service import ModelLoadBalancingService @@ -69,9 +69,13 @@ def _load_balancing_model_config(**kwargs: Any) -> LoadBalancingModelConfig: def service(mocker: MockerFixture) -> ModelLoadBalancingService: # Arrange provider_manager = MagicMock() - mocker.patch("services.model_load_balancing_service.ProviderManager", return_value=provider_manager) + mocker.patch("services.model_load_balancing_service.create_plugin_provider_manager", return_value=provider_manager) + model_assembly = SimpleNamespace(provider_manager=provider_manager, model_provider_factory=MagicMock()) + mocker.patch("services.model_load_balancing_service.create_plugin_model_assembly", return_value=model_assembly) svc = ModelLoadBalancingService() svc.provider_manager = provider_manager + 
svc.model_assembly = model_assembly + svc._get_provider_manager = lambda _tenant_id: provider_manager # type: ignore[method-assign] return svc @@ -666,6 +670,9 @@ def test_validate_load_balancing_credentials_should_delegate_to_custom_validate_ assert mock_validate.call_count == 2 assert mock_validate.call_args_list[0].kwargs["load_balancing_model_config"] is existing_config assert mock_validate.call_args_list[1].kwargs["load_balancing_model_config"] is None + shared_model_provider_factory = service.model_assembly.model_provider_factory + assert mock_validate.call_args_list[0].kwargs["model_provider_factory"] is shared_model_provider_factory + assert mock_validate.call_args_list[1].kwargs["model_provider_factory"] is shared_model_provider_factory def test_custom_credentials_validate_should_replace_hidden_secret_with_original_value_and_encrypt( @@ -708,7 +715,6 @@ def test_custom_credentials_validate_should_handle_invalid_original_json_and_val load_balancing_model_config = _load_balancing_model_config(encrypted_config="not-json") mock_factory = MagicMock() mock_factory.model_credentials_validate.return_value = {"api_key": "validated"} - mocker.patch("services.model_load_balancing_service.ModelProviderFactory", return_value=mock_factory) mock_encrypt = mocker.patch( "services.model_load_balancing_service.encrypter.encrypt_token", side_effect=lambda tenant_id, value: f"enc:{value}", @@ -722,6 +728,7 @@ def test_custom_credentials_validate_should_handle_invalid_original_json_and_val model="gpt-4o-mini", credentials={"api_key": "plain"}, load_balancing_model_config=load_balancing_model_config, + model_provider_factory=mock_factory, validate=True, ) @@ -740,7 +747,6 @@ def test_custom_credentials_validate_should_validate_with_provider_schema_when_m provider_configuration = _build_provider_configuration(provider_schema=_build_provider_credential_schema()) mock_factory = MagicMock() mock_factory.provider_credentials_validate.return_value = {"api_key": "provider-validated"} 
- mocker.patch("services.model_load_balancing_service.ModelProviderFactory", return_value=mock_factory) mocker.patch( "services.model_load_balancing_service.encrypter.encrypt_token", side_effect=lambda tenant_id, value: f"enc:{value}", @@ -753,6 +759,7 @@ def test_custom_credentials_validate_should_validate_with_provider_schema_when_m model_type=ModelType.LLM, model="gpt-4o-mini", credentials={"api_key": "plain"}, + model_provider_factory=mock_factory, validate=True, ) diff --git a/api/tests/unit_tests/services/test_model_provider_service_sanitization.py b/api/tests/unit_tests/services/test_model_provider_service_sanitization.py index 6a6b63f003..1bd979b9ec 100644 --- a/api/tests/unit_tests/services/test_model_provider_service_sanitization.py +++ b/api/tests/unit_tests/services/test_model_provider_service_sanitization.py @@ -1,11 +1,11 @@ import types import pytest +from graphon.model_runtime.entities.common_entities import I18nObject +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.model_runtime.entities.provider_entities import ConfigurateMethod from core.entities.provider_entities import CredentialConfiguration, CustomModelConfiguration -from dify_graph.model_runtime.entities.common_entities import I18nObject -from dify_graph.model_runtime.entities.model_entities import ModelType -from dify_graph.model_runtime.entities.provider_entities import ConfigurateMethod from models.provider import ProviderType from services.model_provider_service import ModelProviderService @@ -71,7 +71,7 @@ def service_with_fake_configurations(): return _FakeConfigurations(fake_provider_configuration) svc = ModelProviderService() - svc.provider_manager = _FakeProviderManager() + svc._get_provider_manager = lambda tenant_id: _FakeProviderManager() return svc diff --git a/api/tests/unit_tests/services/test_restore_archived_workflow_run.py b/api/tests/unit_tests/services/test_restore_archived_workflow_run.py deleted file mode 100644 index 
a214ecf728..0000000000 --- a/api/tests/unit_tests/services/test_restore_archived_workflow_run.py +++ /dev/null @@ -1,37 +0,0 @@ -""" -Unit tests for workflow run restore functionality. -""" - -from datetime import datetime - - -class TestWorkflowRunRestore: - """Tests for the WorkflowRunRestore class.""" - - def test_restore_initialization(self): - """Restore service should respect dry_run flag.""" - from services.retention.workflow_run.restore_archived_workflow_run import WorkflowRunRestore - - restore = WorkflowRunRestore(dry_run=True) - - assert restore.dry_run is True - - def test_convert_datetime_fields(self): - """ISO datetime strings should be converted to datetime objects.""" - from models.workflow import WorkflowRun - from services.retention.workflow_run.restore_archived_workflow_run import WorkflowRunRestore - - record = { - "id": "test-id", - "created_at": "2024-01-01T12:00:00", - "finished_at": "2024-01-01T12:05:00", - "name": "test", - } - - restore = WorkflowRunRestore() - result = restore._convert_datetime_fields(record, WorkflowRun) - - assert isinstance(result["created_at"], datetime) - assert result["created_at"].year == 2024 - assert result["created_at"].month == 1 - assert result["name"] == "test" diff --git a/api/tests/unit_tests/services/test_summary_index_service.py b/api/tests/unit_tests/services/test_summary_index_service.py index ef53df9350..cbf3e121d8 100644 --- a/api/tests/unit_tests/services/test_summary_index_service.py +++ b/api/tests/unit_tests/services/test_summary_index_service.py @@ -191,7 +191,7 @@ def test_vectorize_summary_retries_connection_errors_then_succeeds(monkeypatch: embedding_model.get_text_embedding_num_tokens.return_value = [5] model_manager = MagicMock() model_manager.get_model_instance.return_value = embedding_model - monkeypatch.setattr(summary_module, "ModelManager", MagicMock(return_value=model_manager)) + monkeypatch.setattr(summary_module.ModelManager, "for_tenant", MagicMock(return_value=model_manager)) 
vector_instance = MagicMock() vector_instance.add_texts.side_effect = [RuntimeError("connection timeout"), None] @@ -230,7 +230,7 @@ def test_vectorize_summary_without_session_creates_record_when_missing(monkeypat model_manager = MagicMock() model_manager.get_model_instance.side_effect = RuntimeError("no model") - monkeypatch.setattr(summary_module, "ModelManager", MagicMock(return_value=model_manager)) + monkeypatch.setattr(summary_module.ModelManager, "for_tenant", MagicMock(return_value=model_manager)) # New session used after vectorization succeeds (record not found by id nor chunk_id). session = MagicMock(name="session") @@ -407,8 +407,8 @@ def test_vectorize_summary_updates_existing_record_found_by_chunk_id(monkeypatch vector_instance.add_texts.return_value = None monkeypatch.setattr(summary_module, "Vector", MagicMock(return_value=vector_instance)) monkeypatch.setattr( - summary_module, - "ModelManager", + summary_module.ModelManager, + "for_tenant", MagicMock(return_value=MagicMock(get_model_instance=MagicMock(return_value=None))), ) @@ -441,8 +441,8 @@ def test_vectorize_summary_updates_existing_record_found_by_id(monkeypatch: pyte summary_module, "Vector", MagicMock(return_value=MagicMock(add_texts=MagicMock(return_value=None))) ) monkeypatch.setattr( - summary_module, - "ModelManager", + summary_module.ModelManager, + "for_tenant", MagicMock(return_value=MagicMock(get_model_instance=MagicMock(return_value=None))), ) @@ -474,8 +474,8 @@ def test_vectorize_summary_session_enter_returns_none_triggers_runtime_error(mon summary_module, "Vector", MagicMock(return_value=MagicMock(add_texts=MagicMock(return_value=None))) ) monkeypatch.setattr( - summary_module, - "ModelManager", + summary_module.ModelManager, + "for_tenant", MagicMock(return_value=MagicMock(get_model_instance=MagicMock(return_value=None))), ) @@ -510,8 +510,8 @@ def test_vectorize_summary_created_record_becomes_none_triggers_guard(monkeypatc summary_module, "Vector", 
MagicMock(return_value=MagicMock(add_texts=MagicMock(return_value=None))) ) monkeypatch.setattr( - summary_module, - "ModelManager", + summary_module.ModelManager, + "for_tenant", MagicMock(return_value=MagicMock(get_model_instance=MagicMock(return_value=None))), ) diff --git a/api/tests/unit_tests/services/test_tag_service.py b/api/tests/unit_tests/services/test_tag_service.py deleted file mode 100644 index b09463b1bc..0000000000 --- a/api/tests/unit_tests/services/test_tag_service.py +++ /dev/null @@ -1,1336 +0,0 @@ -""" -Comprehensive unit tests for TagService. - -This test suite provides complete coverage of tag management operations in Dify, -following TDD principles with the Arrange-Act-Assert pattern. - -The TagService is responsible for managing tags that can be associated with -datasets (knowledge bases) and applications. Tags enable users to organize, -categorize, and filter their content effectively. - -## Test Coverage - -### 1. Tag Retrieval (TestTagServiceRetrieval) -Tests tag listing and filtering: -- Get tags with binding counts -- Filter tags by keyword (case-insensitive) -- Get tags by target ID (apps/datasets) -- Get tags by tag name -- Get target IDs by tag IDs -- Empty results handling - -### 2. Tag CRUD Operations (TestTagServiceCRUD) -Tests tag creation, update, and deletion: -- Create new tags -- Prevent duplicate tag names -- Update tag names -- Update with duplicate name validation -- Delete tags and cascade delete bindings -- Get tag binding counts -- NotFound error handling - -### 3. 
Tag Binding Operations (TestTagServiceBindings) -Tests tag-to-resource associations: -- Save tag bindings (apps/datasets) -- Prevent duplicate bindings (idempotent) -- Delete tag bindings -- Check target exists validation -- Batch binding operations - -## Testing Approach - -- **Mocking Strategy**: All external dependencies (database, current_user) are mocked - for fast, isolated unit tests -- **Factory Pattern**: TagServiceTestDataFactory provides consistent test data -- **Fixtures**: Mock objects are configured per test method -- **Assertions**: Each test verifies return values and side effects - (database operations, method calls) - -## Key Concepts - -**Tag Types:** -- knowledge: Tags for datasets/knowledge bases -- app: Tags for applications - -**Tag Bindings:** -- Many-to-many relationship between tags and resources -- Each binding links a tag to a specific app or dataset -- Bindings are tenant-scoped for multi-tenancy - -**Validation:** -- Tag names must be unique within tenant and type -- Target resources must exist before binding -- Cascade deletion of bindings when tag is deleted -""" - - -# ============================================================================ -# IMPORTS -# ============================================================================ - -from datetime import UTC, datetime -from unittest.mock import MagicMock, Mock, create_autospec, patch - -import pytest -from werkzeug.exceptions import NotFound - -from models.dataset import Dataset -from models.enums import TagType -from models.model import App, Tag, TagBinding -from services.tag_service import TagService - -# ============================================================================ -# TEST DATA FACTORY -# ============================================================================ - - -class TagServiceTestDataFactory: - """ - Factory for creating test data and mock objects. - - Provides reusable methods to create consistent mock objects for testing - tag-related operations. 
This factory ensures all test data follows the - same structure and reduces code duplication across tests. - - The factory pattern is used here to: - - Ensure consistent test data creation - - Reduce boilerplate code in individual tests - - Make tests more maintainable and readable - - Centralize mock object configuration - """ - - @staticmethod - def create_tag_mock( - tag_id: str = "tag-123", - name: str = "Test Tag", - tag_type: TagType = TagType.APP, - tenant_id: str = "tenant-123", - **kwargs, - ) -> Mock: - """ - Create a mock Tag object. - - This method creates a mock Tag instance with all required attributes - set to sensible defaults. Additional attributes can be passed via - kwargs to customize the mock for specific test scenarios. - - Args: - tag_id: Unique identifier for the tag - name: Tag name (e.g., "Frontend", "Backend", "Data Science") - tag_type: Type of tag ('app' or 'knowledge') - tenant_id: Tenant identifier for multi-tenancy isolation - **kwargs: Additional attributes to set on the mock - (e.g., created_by, created_at, etc.) - - Returns: - Mock Tag object with specified attributes - - Example: - >>> tag = factory.create_tag_mock( - ... tag_id="tag-456", - ... name="Machine Learning", - ... tag_type="knowledge" - ... 
) - """ - # Create a mock that matches the Tag model interface - tag = create_autospec(Tag, instance=True) - - # Set core attributes - tag.id = tag_id - tag.name = name - tag.type = tag_type - tag.tenant_id = tenant_id - - # Set default optional attributes - tag.created_by = kwargs.pop("created_by", "user-123") - tag.created_at = kwargs.pop("created_at", datetime(2023, 1, 1, 0, 0, 0, tzinfo=UTC)) - - # Apply any additional attributes from kwargs - for key, value in kwargs.items(): - setattr(tag, key, value) - - return tag - - @staticmethod - def create_tag_binding_mock( - binding_id: str = "binding-123", - tag_id: str = "tag-123", - target_id: str = "target-123", - tenant_id: str = "tenant-123", - **kwargs, - ) -> Mock: - """ - Create a mock TagBinding object. - - TagBindings represent the many-to-many relationship between tags - and resources (datasets or apps). This method creates a mock - binding with the necessary attributes. - - Args: - binding_id: Unique identifier for the binding - tag_id: Associated tag identifier - target_id: Associated target (app/dataset) identifier - tenant_id: Tenant identifier for multi-tenancy isolation - **kwargs: Additional attributes to set on the mock - (e.g., created_by, etc.) - - Returns: - Mock TagBinding object with specified attributes - - Example: - >>> binding = factory.create_tag_binding_mock( - ... tag_id="tag-456", - ... target_id="dataset-789", - ... tenant_id="tenant-123" - ... 
) - """ - # Create a mock that matches the TagBinding model interface - binding = create_autospec(TagBinding, instance=True) - - # Set core attributes - binding.id = binding_id - binding.tag_id = tag_id - binding.target_id = target_id - binding.tenant_id = tenant_id - - # Set default optional attributes - binding.created_by = kwargs.pop("created_by", "user-123") - - # Apply any additional attributes from kwargs - for key, value in kwargs.items(): - setattr(binding, key, value) - - return binding - - @staticmethod - def create_app_mock(app_id: str = "app-123", tenant_id: str = "tenant-123", **kwargs) -> Mock: - """ - Create a mock App object. - - This method creates a mock App instance for testing tag bindings - to applications. Apps are one of the two target types that tags - can be bound to (the other being datasets/knowledge bases). - - Args: - app_id: Unique identifier for the app - tenant_id: Tenant identifier for multi-tenancy isolation - **kwargs: Additional attributes to set on the mock - - Returns: - Mock App object with specified attributes - - Example: - >>> app = factory.create_app_mock( - ... app_id="app-456", - ... name="My Chat App" - ... ) - """ - # Create a mock that matches the App model interface - app = create_autospec(App, instance=True) - - # Set core attributes - app.id = app_id - app.tenant_id = tenant_id - app.name = kwargs.get("name", "Test App") - - # Apply any additional attributes from kwargs - for key, value in kwargs.items(): - setattr(app, key, value) - - return app - - @staticmethod - def create_dataset_mock(dataset_id: str = "dataset-123", tenant_id: str = "tenant-123", **kwargs) -> Mock: - """ - Create a mock Dataset object. - - This method creates a mock Dataset instance for testing tag bindings - to knowledge bases. Datasets (knowledge bases) are one of the two - target types that tags can be bound to (the other being apps). 
- - Args: - dataset_id: Unique identifier for the dataset - tenant_id: Tenant identifier for multi-tenancy isolation - **kwargs: Additional attributes to set on the mock - - Returns: - Mock Dataset object with specified attributes - - Example: - >>> dataset = factory.create_dataset_mock( - ... dataset_id="dataset-456", - ... name="My Knowledge Base" - ... ) - """ - # Create a mock that matches the Dataset model interface - dataset = create_autospec(Dataset, instance=True) - - # Set core attributes - dataset.id = dataset_id - dataset.tenant_id = tenant_id - dataset.name = kwargs.pop("name", "Test Dataset") - - # Apply any additional attributes from kwargs - for key, value in kwargs.items(): - setattr(dataset, key, value) - - return dataset - - -# ============================================================================ -# PYTEST FIXTURES -# ============================================================================ - - -@pytest.fixture -def factory(): - """ - Provide the test data factory to all tests. - - This fixture makes the TagServiceTestDataFactory available to all test - methods, allowing them to create consistent mock objects easily. - - Returns: - TagServiceTestDataFactory class - """ - return TagServiceTestDataFactory - - -# ============================================================================ -# TAG RETRIEVAL TESTS -# ============================================================================ - - -class TestTagServiceRetrieval: - """ - Test tag retrieval operations. - - This test class covers all methods related to retrieving and querying - tags from the system. These operations are read-only and do not modify - the database state. 
- - Methods tested: - - get_tags: Retrieve tags with optional keyword filtering - - get_target_ids_by_tag_ids: Get target IDs (datasets/apps) by tag IDs - - get_tag_by_tag_name: Find tags by exact name match - - get_tags_by_target_id: Get all tags bound to a specific target - """ - - @patch("services.tag_service.db.session") - def test_get_tags_with_binding_counts(self, mock_db_session, factory): - """ - Test retrieving tags with their binding counts. - - This test verifies that the get_tags method correctly retrieves - a list of tags along with the count of how many resources - (datasets/apps) are bound to each tag. - - The method should: - - Query tags filtered by type and tenant - - Include binding counts via a LEFT OUTER JOIN - - Return results ordered by creation date (newest first) - - Expected behavior: - - Returns a list of tuples containing (id, type, name, binding_count) - - Each tag includes its binding count - - Results are ordered by creation date descending - """ - # Arrange - # Set up test parameters - tenant_id = "tenant-123" - tag_type = "app" - - # Mock query results: tuples of (tag_id, type, name, binding_count) - # This simulates the SQL query result with aggregated binding counts - mock_results = [ - ("tag-1", "app", "Frontend", 5), # Frontend tag with 5 bindings - ("tag-2", "app", "Backend", 3), # Backend tag with 3 bindings - ("tag-3", "app", "API", 0), # API tag with no bindings - ] - - # Configure mock database session and query chain - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.outerjoin.return_value = mock_query # LEFT OUTER JOIN with TagBinding - mock_query.where.return_value = mock_query # WHERE clause for filtering - mock_query.group_by.return_value = mock_query # GROUP BY for aggregation - mock_query.order_by.return_value = mock_query # ORDER BY for sorting - mock_query.all.return_value = mock_results # Final result - - # Act - # Execute the method under test - results = 
TagService.get_tags(tag_type=tag_type, current_tenant_id=tenant_id) - - # Assert - # Verify the results match expectations - assert len(results) == 3, "Should return 3 tags" - - # Verify each tag's data structure - assert results[0] == ("tag-1", "app", "Frontend", 5), "First tag should match" - assert results[1] == ("tag-2", "app", "Backend", 3), "Second tag should match" - assert results[2] == ("tag-3", "app", "API", 0), "Third tag should match" - - # Verify database query was called - mock_db_session.query.assert_called_once() - - @patch("services.tag_service.db.session") - def test_get_tags_with_keyword_filter(self, mock_db_session, factory): - """ - Test retrieving tags filtered by keyword (case-insensitive). - - This test verifies that the get_tags method correctly filters tags - by keyword when a keyword parameter is provided. The filtering - should be case-insensitive and support partial matches. - - The method should: - - Apply an additional WHERE clause when keyword is provided - - Use ILIKE for case-insensitive pattern matching - - Support partial matches (e.g., "data" matches "Database" and "Data Science") - - Expected behavior: - - Returns only tags whose names contain the keyword - - Matching is case-insensitive - - Partial matches are supported - """ - # Arrange - # Set up test parameters - tenant_id = "tenant-123" - tag_type = "knowledge" - keyword = "data" - - # Mock query results filtered by keyword - mock_results = [ - ("tag-1", "knowledge", "Database", 2), - ("tag-2", "knowledge", "Data Science", 4), - ] - - # Configure mock database session and query chain - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.outerjoin.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.group_by.return_value = mock_query - mock_query.order_by.return_value = mock_query - mock_query.all.return_value = mock_results - - # Act - # Execute the method with keyword filter - results = 
TagService.get_tags(tag_type=tag_type, current_tenant_id=tenant_id, keyword=keyword) - - # Assert - # Verify filtered results - assert len(results) == 2, "Should return 2 matching tags" - - # Verify keyword filter was applied - # The where() method should be called at least twice: - # 1. Initial WHERE clause for type and tenant - # 2. Additional WHERE clause for keyword filtering - assert mock_query.where.call_count >= 2, "Keyword filter should add WHERE clause" - - @patch("services.tag_service.db.session") - def test_get_target_ids_by_tag_ids(self, mock_db_session, factory): - """ - Test retrieving target IDs by tag IDs. - - This test verifies that the get_target_ids_by_tag_ids method correctly - retrieves all target IDs (dataset/app IDs) that are bound to the - specified tags. This is useful for filtering datasets or apps by tags. - - The method should: - - First validate and filter tags by type and tenant - - Then find all bindings for those tags - - Return the target IDs from those bindings - - Expected behavior: - - Returns a list of target IDs (strings) - - Only includes targets bound to valid tags - - Respects tenant and type filtering - """ - # Arrange - # Set up test parameters - tenant_id = "tenant-123" - tag_type = "app" - tag_ids = ["tag-1", "tag-2"] - - # Create mock tag objects - tags = [ - factory.create_tag_mock(tag_id="tag-1", tenant_id=tenant_id, tag_type=tag_type), - factory.create_tag_mock(tag_id="tag-2", tenant_id=tenant_id, tag_type=tag_type), - ] - - # Mock target IDs that are bound to these tags - target_ids = ["app-1", "app-2", "app-3"] - - # Mock tag query (first scalars call) - mock_scalars_tags = MagicMock() - mock_scalars_tags.all.return_value = tags - - # Mock binding query (second scalars call) - mock_scalars_bindings = MagicMock() - mock_scalars_bindings.all.return_value = target_ids - - # Configure side_effect to return different mocks for each scalars() call - mock_db_session.scalars.side_effect = [mock_scalars_tags, 
mock_scalars_bindings] - - # Act - # Execute the method under test - results = TagService.get_target_ids_by_tag_ids(tag_type=tag_type, current_tenant_id=tenant_id, tag_ids=tag_ids) - - # Assert - # Verify results match expected target IDs - assert results == target_ids, "Should return all target IDs bound to tags" - - # Verify both queries were executed - assert mock_db_session.scalars.call_count == 2, "Should execute tag query and binding query" - - @patch("services.tag_service.db.session") - def test_get_target_ids_with_empty_tag_ids(self, mock_db_session, factory): - """ - Test that empty tag_ids returns empty list. - - This test verifies the edge case handling when an empty list of - tag IDs is provided. The method should return early without - executing any database queries. - - Expected behavior: - - Returns empty list immediately - - Does not execute any database queries - - Handles empty input gracefully - """ - # Arrange - # Set up test parameters with empty tag IDs - tenant_id = "tenant-123" - tag_type = "app" - - # Act - # Execute the method with empty tag IDs list - results = TagService.get_target_ids_by_tag_ids(tag_type=tag_type, current_tenant_id=tenant_id, tag_ids=[]) - - # Assert - # Verify empty result and no database queries - assert results == [], "Should return empty list for empty input" - mock_db_session.scalars.assert_not_called(), "Should not query database for empty input" - - @patch("services.tag_service.db.session") - def test_get_tag_by_tag_name(self, mock_db_session, factory): - """ - Test retrieving tags by name. - - This test verifies that the get_tag_by_tag_name method correctly - finds tags by their exact name. This is used for duplicate name - checking and tag lookup operations. 
- - The method should: - - Perform exact name matching (case-sensitive) - - Filter by type and tenant - - Return a list of matching tags (usually 0 or 1) - - Expected behavior: - - Returns list of tags with matching name - - Respects type and tenant filtering - - Returns empty list if no matches found - """ - # Arrange - # Set up test parameters - tenant_id = "tenant-123" - tag_type = "app" - tag_name = "Production" - - # Create mock tag with matching name - tags = [factory.create_tag_mock(name=tag_name, tag_type=tag_type, tenant_id=tenant_id)] - - # Configure mock database session - mock_scalars = MagicMock() - mock_scalars.all.return_value = tags - mock_db_session.scalars.return_value = mock_scalars - - # Act - # Execute the method under test - results = TagService.get_tag_by_tag_name(tag_type=tag_type, current_tenant_id=tenant_id, tag_name=tag_name) - - # Assert - # Verify tag was found - assert len(results) == 1, "Should find exactly one tag" - assert results[0].name == tag_name, "Tag name should match" - - @patch("services.tag_service.db.session") - def test_get_tag_by_tag_name_returns_empty_for_missing_params(self, mock_db_session, factory): - """ - Test that missing tag_type or tag_name returns empty list. - - This test verifies the input validation for the get_tag_by_tag_name - method. When either tag_type or tag_name is empty or missing, - the method should return early without querying the database. 
- - Expected behavior: - - Returns empty list for empty tag_type - - Returns empty list for empty tag_name - - Does not execute database queries for invalid input - """ - # Arrange - # Set up test parameters - tenant_id = "tenant-123" - - # Act & Assert - # Test with empty tag_type - assert TagService.get_tag_by_tag_name("", tenant_id, "name") == [], "Should return empty for empty type" - - # Test with empty tag_name - assert TagService.get_tag_by_tag_name("app", tenant_id, "") == [], "Should return empty for empty name" - - # Verify no database queries were executed - mock_db_session.scalars.assert_not_called(), "Should not query database for invalid input" - - @patch("services.tag_service.db.session") - def test_get_tags_by_target_id(self, mock_db_session, factory): - """ - Test retrieving tags associated with a specific target. - - This test verifies that the get_tags_by_target_id method correctly - retrieves all tags that are bound to a specific target (dataset or app). - This is useful for displaying tags associated with a resource. 
- - The method should: - - Join Tag and TagBinding tables - - Filter by target_id, tenant, and type - - Return all tags bound to the target - - Expected behavior: - - Returns list of Tag objects bound to the target - - Respects tenant and type filtering - - Returns empty list if no tags are bound - """ - # Arrange - # Set up test parameters - tenant_id = "tenant-123" - tag_type = "app" - target_id = "app-123" - - # Create mock tags that are bound to the target - tags = [ - factory.create_tag_mock(tag_id="tag-1", name="Frontend"), - factory.create_tag_mock(tag_id="tag-2", name="Production"), - ] - - # Configure mock database session and query chain - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.join.return_value = mock_query # JOIN with TagBinding - mock_query.where.return_value = mock_query # WHERE clause for filtering - mock_query.all.return_value = tags # Final result - - # Act - # Execute the method under test - results = TagService.get_tags_by_target_id(tag_type=tag_type, current_tenant_id=tenant_id, target_id=target_id) - - # Assert - # Verify tags were retrieved - assert len(results) == 2, "Should return 2 tags bound to target" - - # Verify tag names - assert results[0].name == "Frontend", "First tag name should match" - assert results[1].name == "Production", "Second tag name should match" - - -# ============================================================================ -# TAG CRUD OPERATIONS TESTS -# ============================================================================ - - -class TestTagServiceCRUD: - """ - Test tag CRUD operations. - - This test class covers all Create, Read, Update, and Delete operations - for tags. These operations modify the database state and require proper - transaction handling and validation. 
- - Methods tested: - - save_tags: Create new tags - - update_tags: Update existing tag names - - delete_tag: Delete tags and cascade delete bindings - - get_tag_binding_count: Get count of bindings for a tag - """ - - @patch("services.tag_service.current_user", autospec=True) - @patch("services.tag_service.TagService.get_tag_by_tag_name", autospec=True) - @patch("services.tag_service.db.session") - @patch("services.tag_service.uuid.uuid4", autospec=True) - def test_save_tags(self, mock_uuid, mock_db_session, mock_get_tag_by_name, mock_current_user, factory): - """ - Test creating a new tag. - - This test verifies that the save_tags method correctly creates a new - tag in the database with all required attributes. The method should - validate uniqueness, generate a UUID, and persist the tag. - - The method should: - - Check for duplicate tag names (via get_tag_by_tag_name) - - Generate a unique UUID for the tag ID - - Set user and tenant information from current_user - - Persist the tag to the database - - Commit the transaction - - Expected behavior: - - Creates tag with correct attributes - - Assigns UUID to tag ID - - Sets created_by from current_user - - Sets tenant_id from current_user - - Commits to database - """ - # Arrange - # Configure mock current_user - mock_current_user.id = "user-123" - mock_current_user.current_tenant_id = "tenant-123" - - # Mock UUID generation - mock_uuid.return_value = "new-tag-id" - - # Mock no existing tag (duplicate check passes) - mock_get_tag_by_name.return_value = [] - - # Prepare tag creation arguments - args = {"name": "New Tag", "type": "app"} - - # Act - # Execute the method under test - result = TagService.save_tags(args) - - # Assert - # Verify tag was added to database session - mock_db_session.add.assert_called_once(), "Should add tag to session" - - # Verify transaction was committed - mock_db_session.commit.assert_called_once(), "Should commit transaction" - - # Verify tag attributes - added_tag = 
mock_db_session.add.call_args[0][0] - assert added_tag.name == "New Tag", "Tag name should match" - assert added_tag.type == TagType.APP, "Tag type should match" - assert added_tag.created_by == "user-123", "Created by should match current user" - assert added_tag.tenant_id == "tenant-123", "Tenant ID should match current tenant" - - @patch("services.tag_service.current_user", autospec=True) - @patch("services.tag_service.TagService.get_tag_by_tag_name", autospec=True) - def test_save_tags_raises_error_for_duplicate_name(self, mock_get_tag_by_name, mock_current_user, factory): - """ - Test that creating a tag with duplicate name raises ValueError. - - This test verifies that the save_tags method correctly prevents - duplicate tag names within the same tenant and type. Tag names - must be unique per tenant and type combination. - - Expected behavior: - - Raises ValueError when duplicate name is detected - - Error message indicates "Tag name already exists" - - Does not create the tag - """ - # Arrange - # Configure mock current_user - mock_current_user.current_tenant_id = "tenant-123" - - # Mock existing tag with same name (duplicate detected) - existing_tag = factory.create_tag_mock(name="Existing Tag") - mock_get_tag_by_name.return_value = [existing_tag] - - # Prepare tag creation arguments with duplicate name - args = {"name": "Existing Tag", "type": "app"} - - # Act & Assert - # Verify ValueError is raised for duplicate name - with pytest.raises(ValueError, match="Tag name already exists"): - TagService.save_tags(args) - - @patch("services.tag_service.current_user", autospec=True) - @patch("services.tag_service.TagService.get_tag_by_tag_name", autospec=True) - @patch("services.tag_service.db.session") - def test_update_tags(self, mock_db_session, mock_get_tag_by_name, mock_current_user, factory): - """ - Test updating a tag name. - - This test verifies that the update_tags method correctly updates - an existing tag's name while preserving other attributes. 
The method - should validate uniqueness of the new name and ensure the tag exists. - - The method should: - - Check for duplicate tag names (excluding the current tag) - - Find the tag by ID - - Update the tag name - - Commit the transaction - - Expected behavior: - - Updates tag name successfully - - Preserves other tag attributes - - Commits to database - """ - # Arrange - # Configure mock current_user - mock_current_user.current_tenant_id = "tenant-123" - - # Mock no duplicate name (update check passes) - mock_get_tag_by_name.return_value = [] - - # Create mock tag to be updated - tag = factory.create_tag_mock(tag_id="tag-123", name="Old Name") - - # Configure mock database session to return the tag - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = tag - - # Prepare update arguments - args = {"name": "New Name", "type": "app"} - - # Act - # Execute the method under test - result = TagService.update_tags(args, tag_id="tag-123") - - # Assert - # Verify tag name was updated - assert tag.name == "New Name", "Tag name should be updated" - - # Verify transaction was committed - mock_db_session.commit.assert_called_once(), "Should commit transaction" - - @patch("services.tag_service.current_user", autospec=True) - @patch("services.tag_service.TagService.get_tag_by_tag_name", autospec=True) - @patch("services.tag_service.db.session") - def test_update_tags_raises_error_for_duplicate_name( - self, mock_db_session, mock_get_tag_by_name, mock_current_user, factory - ): - """ - Test that updating to a duplicate name raises ValueError. - - This test verifies that the update_tags method correctly prevents - updating a tag to a name that already exists for another tag - within the same tenant and type. 
- - Expected behavior: - - Raises ValueError when duplicate name is detected - - Error message indicates "Tag name already exists" - - Does not update the tag - """ - # Arrange - # Configure mock current_user - mock_current_user.current_tenant_id = "tenant-123" - - # Mock existing tag with the duplicate name - existing_tag = factory.create_tag_mock(name="Duplicate Name") - mock_get_tag_by_name.return_value = [existing_tag] - - # Prepare update arguments with duplicate name - args = {"name": "Duplicate Name", "type": "app"} - - # Act & Assert - # Verify ValueError is raised for duplicate name - with pytest.raises(ValueError, match="Tag name already exists"): - TagService.update_tags(args, tag_id="tag-123") - - @patch("services.tag_service.db.session") - def test_update_tags_raises_not_found_for_missing_tag(self, mock_db_session, factory): - """ - Test that updating a non-existent tag raises NotFound. - - This test verifies that the update_tags method correctly handles - the case when attempting to update a tag that does not exist. - This prevents silent failures and provides clear error feedback. 
- - Expected behavior: - - Raises NotFound exception - - Error message indicates "Tag not found" - - Does not attempt to update or commit - """ - # Arrange - # Configure mock database session to return None (tag not found) - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = None - - # Mock duplicate check and current_user - with patch("services.tag_service.TagService.get_tag_by_tag_name", return_value=[], autospec=True): - with patch("services.tag_service.current_user", autospec=True) as mock_user: - mock_user.current_tenant_id = "tenant-123" - args = {"name": "New Name", "type": "app"} - - # Act & Assert - # Verify NotFound is raised for non-existent tag - with pytest.raises(NotFound, match="Tag not found"): - TagService.update_tags(args, tag_id="nonexistent") - - @patch("services.tag_service.db.session") - def test_get_tag_binding_count(self, mock_db_session, factory): - """ - Test getting the count of bindings for a tag. - - This test verifies that the get_tag_binding_count method correctly - counts how many resources (datasets/apps) are bound to a specific tag. - This is useful for displaying tag usage statistics. 
- - The method should: - - Query TagBinding table filtered by tag_id - - Return the count of matching bindings - - Expected behavior: - - Returns integer count of bindings - - Returns 0 for tags with no bindings - """ - # Arrange - # Set up test parameters - tag_id = "tag-123" - expected_count = 5 - - # Configure mock database session - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.count.return_value = expected_count - - # Act - # Execute the method under test - result = TagService.get_tag_binding_count(tag_id) - - # Assert - # Verify count matches expectation - assert result == expected_count, "Binding count should match" - - @patch("services.tag_service.db.session") - def test_delete_tag(self, mock_db_session, factory): - """ - Test deleting a tag and its bindings. - - This test verifies that the delete_tag method correctly deletes - a tag along with all its associated bindings (cascade delete). - This ensures data integrity and prevents orphaned bindings. 
- - The method should: - - Find the tag by ID - - Delete the tag - - Find all bindings for the tag - - Delete all bindings (cascade delete) - - Commit the transaction - - Expected behavior: - - Deletes tag from database - - Deletes all associated bindings - - Commits transaction - """ - # Arrange - # Set up test parameters - tag_id = "tag-123" - - # Create mock tag to be deleted - tag = factory.create_tag_mock(tag_id=tag_id) - - # Create mock bindings that will be cascade deleted - bindings = [factory.create_tag_binding_mock(binding_id=f"binding-{i}", tag_id=tag_id) for i in range(3)] - - # Configure mock database session for tag query - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = tag - - # Configure mock database session for bindings query - mock_scalars = MagicMock() - mock_scalars.all.return_value = bindings - mock_db_session.scalars.return_value = mock_scalars - - # Act - # Execute the method under test - TagService.delete_tag(tag_id) - - # Assert - # Verify tag and bindings were deleted - mock_db_session.delete.assert_called(), "Should call delete method" - - # Verify delete was called 4 times (1 tag + 3 bindings) - assert mock_db_session.delete.call_count == 4, "Should delete tag and all bindings" - - # Verify transaction was committed - mock_db_session.commit.assert_called_once(), "Should commit transaction" - - @patch("services.tag_service.db.session") - def test_delete_tag_raises_not_found(self, mock_db_session, factory): - """ - Test that deleting a non-existent tag raises NotFound. - - This test verifies that the delete_tag method correctly handles - the case when attempting to delete a tag that does not exist. - This prevents silent failures and provides clear error feedback. 
- - Expected behavior: - - Raises NotFound exception - - Error message indicates "Tag not found" - - Does not attempt to delete or commit - """ - # Arrange - # Configure mock database session to return None (tag not found) - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = None - - # Act & Assert - # Verify NotFound is raised for non-existent tag - with pytest.raises(NotFound, match="Tag not found"): - TagService.delete_tag("nonexistent") - - -# ============================================================================ -# TAG BINDING OPERATIONS TESTS -# ============================================================================ - - -class TestTagServiceBindings: - """ - Test tag binding operations. - - This test class covers all operations related to binding tags to - resources (datasets and apps). Tag bindings create the many-to-many - relationship between tags and resources. - - Methods tested: - - save_tag_binding: Create bindings between tags and targets - - delete_tag_binding: Remove bindings between tags and targets - - check_target_exists: Validate target (dataset/app) existence - """ - - @patch("services.tag_service.current_user", autospec=True) - @patch("services.tag_service.TagService.check_target_exists", autospec=True) - @patch("services.tag_service.db.session") - def test_save_tag_binding(self, mock_db_session, mock_check_target, mock_current_user, factory): - """ - Test creating tag bindings. - - This test verifies that the save_tag_binding method correctly - creates bindings between tags and a target resource (dataset or app). - The method supports batch binding of multiple tags to a single target. 
- - The method should: - - Validate target exists (via check_target_exists) - - Check for existing bindings to avoid duplicates - - Create new bindings for tags that aren't already bound - - Commit the transaction - - Expected behavior: - - Validates target exists - - Creates bindings for each tag in tag_ids - - Skips tags that are already bound (idempotent) - - Commits transaction - """ - # Arrange - # Configure mock current_user - mock_current_user.id = "user-123" - mock_current_user.current_tenant_id = "tenant-123" - - # Configure mock database session (no existing bindings) - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = None # No existing bindings - - # Prepare binding arguments (batch binding) - args = {"type": "app", "target_id": "app-123", "tag_ids": ["tag-1", "tag-2"]} - - # Act - # Execute the method under test - TagService.save_tag_binding(args) - - # Assert - # Verify target existence was checked - mock_check_target.assert_called_once_with("app", "app-123"), "Should validate target exists" - - # Verify bindings were created (2 bindings for 2 tags) - assert mock_db_session.add.call_count == 2, "Should create 2 bindings" - - # Verify transaction was committed - mock_db_session.commit.assert_called_once(), "Should commit transaction" - - @patch("services.tag_service.current_user", autospec=True) - @patch("services.tag_service.TagService.check_target_exists", autospec=True) - @patch("services.tag_service.db.session") - def test_save_tag_binding_is_idempotent(self, mock_db_session, mock_check_target, mock_current_user, factory): - """ - Test that saving duplicate bindings is idempotent. - - This test verifies that the save_tag_binding method correctly handles - the case when attempting to create a binding that already exists. - The method should skip existing bindings and not create duplicates, - making the operation idempotent. 
- - Expected behavior: - - Checks for existing bindings - - Skips tags that are already bound - - Does not create duplicate bindings - - Still commits transaction - """ - # Arrange - # Configure mock current_user - mock_current_user.id = "user-123" - mock_current_user.current_tenant_id = "tenant-123" - - # Mock existing binding (duplicate detected) - existing_binding = factory.create_tag_binding_mock() - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = existing_binding # Binding already exists - - # Prepare binding arguments - args = {"type": "app", "target_id": "app-123", "tag_ids": ["tag-1"]} - - # Act - # Execute the method under test - TagService.save_tag_binding(args) - - # Assert - # Verify no new binding was added (idempotent) - mock_db_session.add.assert_not_called(), "Should not create duplicate binding" - - @patch("services.tag_service.TagService.check_target_exists", autospec=True) - @patch("services.tag_service.db.session") - def test_delete_tag_binding(self, mock_db_session, mock_check_target, factory): - """ - Test deleting a tag binding. - - This test verifies that the delete_tag_binding method correctly - removes a binding between a tag and a target resource. This - operation should be safe even if the binding doesn't exist. 
- - The method should: - - Validate target exists (via check_target_exists) - - Find the binding by tag_id and target_id - - Delete the binding if it exists - - Commit the transaction - - Expected behavior: - - Validates target exists - - Deletes the binding - - Commits transaction - """ - # Arrange - # Create mock binding to be deleted - binding = factory.create_tag_binding_mock() - - # Configure mock database session - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = binding - - # Prepare delete arguments - args = {"type": "app", "target_id": "app-123", "tag_id": "tag-1"} - - # Act - # Execute the method under test - TagService.delete_tag_binding(args) - - # Assert - # Verify target existence was checked - mock_check_target.assert_called_once_with("app", "app-123"), "Should validate target exists" - - # Verify binding was deleted - mock_db_session.delete.assert_called_once_with(binding), "Should delete the binding" - - # Verify transaction was committed - mock_db_session.commit.assert_called_once(), "Should commit transaction" - - @patch("services.tag_service.TagService.check_target_exists", autospec=True) - @patch("services.tag_service.db.session") - def test_delete_tag_binding_does_nothing_if_not_exists(self, mock_db_session, mock_check_target, factory): - """ - Test that deleting a non-existent binding is a no-op. - - This test verifies that the delete_tag_binding method correctly - handles the case when attempting to delete a binding that doesn't - exist. The method should not raise an error and should not commit - if there's nothing to delete. 
- - Expected behavior: - - Validates target exists - - Does not raise error for non-existent binding - - Does not call delete or commit if binding doesn't exist - """ - # Arrange - # Configure mock database session (binding not found) - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = None # Binding doesn't exist - - # Prepare delete arguments - args = {"type": "app", "target_id": "app-123", "tag_id": "tag-1"} - - # Act - # Execute the method under test - TagService.delete_tag_binding(args) - - # Assert - # Verify no delete operation was attempted - mock_db_session.delete.assert_not_called(), "Should not delete if binding doesn't exist" - - # Verify no commit was made (nothing changed) - mock_db_session.commit.assert_not_called(), "Should not commit if nothing to delete" - - @patch("services.tag_service.current_user", autospec=True) - @patch("services.tag_service.db.session") - def test_check_target_exists_for_dataset(self, mock_db_session, mock_current_user, factory): - """ - Test validating that a dataset target exists. - - This test verifies that the check_target_exists method correctly - validates the existence of a dataset (knowledge base) when the - target type is "knowledge". This validation ensures bindings - are only created for valid resources. 
- - The method should: - - Query Dataset table filtered by tenant and ID - - Raise NotFound if dataset doesn't exist - - Return normally if dataset exists - - Expected behavior: - - No exception raised when dataset exists - - Database query is executed - """ - # Arrange - # Configure mock current_user - mock_current_user.current_tenant_id = "tenant-123" - - # Create mock dataset - dataset = factory.create_dataset_mock() - - # Configure mock database session - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = dataset # Dataset exists - - # Act - # Execute the method under test - TagService.check_target_exists("knowledge", "dataset-123") - - # Assert - # Verify no exception was raised and query was executed - mock_db_session.query.assert_called_once(), "Should query database for dataset" - - @patch("services.tag_service.current_user", autospec=True) - @patch("services.tag_service.db.session") - def test_check_target_exists_for_app(self, mock_db_session, mock_current_user, factory): - """ - Test validating that an app target exists. - - This test verifies that the check_target_exists method correctly - validates the existence of an application when the target type is - "app". This validation ensures bindings are only created for valid - resources. 
- - The method should: - - Query App table filtered by tenant and ID - - Raise NotFound if app doesn't exist - - Return normally if app exists - - Expected behavior: - - No exception raised when app exists - - Database query is executed - """ - # Arrange - # Configure mock current_user - mock_current_user.current_tenant_id = "tenant-123" - - # Create mock app - app = factory.create_app_mock() - - # Configure mock database session - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = app # App exists - - # Act - # Execute the method under test - TagService.check_target_exists("app", "app-123") - - # Assert - # Verify no exception was raised and query was executed - mock_db_session.query.assert_called_once(), "Should query database for app" - - @patch("services.tag_service.current_user", autospec=True) - @patch("services.tag_service.db.session") - def test_check_target_exists_raises_not_found_for_missing_dataset( - self, mock_db_session, mock_current_user, factory - ): - """ - Test that missing dataset raises NotFound. - - This test verifies that the check_target_exists method correctly - raises a NotFound exception when attempting to validate a dataset - that doesn't exist. This prevents creating bindings for invalid - resources. 
- - Expected behavior: - - Raises NotFound exception - - Error message indicates "Dataset not found" - """ - # Arrange - # Configure mock current_user - mock_current_user.current_tenant_id = "tenant-123" - - # Configure mock database session (dataset not found) - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = None # Dataset doesn't exist - - # Act & Assert - # Verify NotFound is raised for non-existent dataset - with pytest.raises(NotFound, match="Dataset not found"): - TagService.check_target_exists("knowledge", "nonexistent") - - @patch("services.tag_service.current_user", autospec=True) - @patch("services.tag_service.db.session") - def test_check_target_exists_raises_not_found_for_missing_app(self, mock_db_session, mock_current_user, factory): - """ - Test that missing app raises NotFound. - - This test verifies that the check_target_exists method correctly - raises a NotFound exception when attempting to validate an app - that doesn't exist. This prevents creating bindings for invalid - resources. - - Expected behavior: - - Raises NotFound exception - - Error message indicates "App not found" - """ - # Arrange - # Configure mock current_user - mock_current_user.current_tenant_id = "tenant-123" - - # Configure mock database session (app not found) - mock_query = MagicMock() - mock_db_session.query.return_value = mock_query - mock_query.where.return_value = mock_query - mock_query.first.return_value = None # App doesn't exist - - # Act & Assert - # Verify NotFound is raised for non-existent app - with pytest.raises(NotFound, match="App not found"): - TagService.check_target_exists("app", "nonexistent") - - def test_check_target_exists_raises_not_found_for_invalid_type(self, factory): - """ - Test that invalid binding type raises NotFound. 
- - This test verifies that the check_target_exists method correctly - raises a NotFound exception when an invalid target type is provided. - Only "knowledge" (for datasets) and "app" are valid target types. - - Expected behavior: - - Raises NotFound exception - - Error message indicates "Invalid binding type" - """ - # Act & Assert - # Verify NotFound is raised for invalid target type - with pytest.raises(NotFound, match="Invalid binding type"): - TagService.check_target_exists("invalid_type", "target-123") diff --git a/api/tests/unit_tests/services/test_variable_truncator.py b/api/tests/unit_tests/services/test_variable_truncator.py index c703ab64d0..9c23135225 100644 --- a/api/tests/unit_tests/services/test_variable_truncator.py +++ b/api/tests/unit_tests/services/test_variable_truncator.py @@ -16,10 +16,8 @@ from typing import Any from uuid import uuid4 import pytest - -from dify_graph.file.enums import FileTransferMethod, FileType -from dify_graph.file.models import File -from dify_graph.variables.segments import ( +from graphon.file import File, FileTransferMethod, FileType +from graphon.variables.segments import ( ArrayFileSegment, ArrayNumberSegment, ArraySegment, @@ -30,6 +28,7 @@ from dify_graph.variables.segments import ( ObjectSegment, StringSegment, ) + from services.variable_truncator import ( DummyVariableTruncator, MaxDepthExceededError, diff --git a/api/tests/unit_tests/services/test_vector_service.py b/api/tests/unit_tests/services/test_vector_service.py index 16d3011810..598ff3fc3a 100644 --- a/api/tests/unit_tests/services/test_vector_service.py +++ b/api/tests/unit_tests/services/test_vector_service.py @@ -214,7 +214,9 @@ def test_create_segments_vector_parent_child_calls_generate_child_chunks_with_ex embedding_model_instance = MagicMock(name="embedding_model_instance") model_manager_instance = MagicMock(name="model_manager_instance") model_manager_instance.get_model_instance.return_value = embedding_model_instance - 
monkeypatch.setattr(vector_service_module, "ModelManager", MagicMock(return_value=model_manager_instance)) + monkeypatch.setattr( + vector_service_module.ModelManager, "for_tenant", MagicMock(return_value=model_manager_instance) + ) generate_child_chunks_mock = MagicMock() monkeypatch.setattr(VectorService, "generate_child_chunks", generate_child_chunks_mock) @@ -262,7 +264,9 @@ def test_create_segments_vector_parent_child_uses_default_embedding_model_when_p embedding_model_instance = MagicMock() model_manager_instance = MagicMock() model_manager_instance.get_default_model_instance.return_value = embedding_model_instance - monkeypatch.setattr(vector_service_module, "ModelManager", MagicMock(return_value=model_manager_instance)) + monkeypatch.setattr( + vector_service_module.ModelManager, "for_tenant", MagicMock(return_value=model_manager_instance) + ) generate_child_chunks_mock = MagicMock() monkeypatch.setattr(VectorService, "generate_child_chunks", generate_child_chunks_mock) diff --git a/api/tests/unit_tests/services/test_workflow_run_service_pause.py b/api/tests/unit_tests/services/test_workflow_run_service_pause.py index 27664c7e29..a62c9f4555 100644 --- a/api/tests/unit_tests/services/test_workflow_run_service_pause.py +++ b/api/tests/unit_tests/services/test_workflow_run_service_pause.py @@ -13,10 +13,10 @@ from datetime import datetime from unittest.mock import MagicMock, create_autospec, patch import pytest +from graphon.enums import WorkflowExecutionStatus from sqlalchemy import Engine from sqlalchemy.orm import Session, sessionmaker -from dify_graph.enums import WorkflowExecutionStatus from models.workflow import WorkflowPause from repositories.api_workflow_run_repository import APIWorkflowRunRepository from repositories.sqlalchemy_api_workflow_run_repository import _PrivateWorkflowPauseEntity diff --git a/api/tests/unit_tests/services/test_workflow_service.py b/api/tests/unit_tests/services/test_workflow_service.py index d26c2f674f..cd71981bcf 100644 
--- a/api/tests/unit_tests/services/test_workflow_service.py +++ b/api/tests/unit_tests/services/test_workflow_service.py @@ -12,22 +12,23 @@ This test suite covers: import json import uuid from typing import Any, cast -from unittest.mock import MagicMock, patch +from unittest.mock import ANY, MagicMock, patch import pytest - -from dify_graph.entities import WorkflowNodeExecution -from dify_graph.enums import ( +from graphon.entities import WorkflowNodeExecution +from graphon.enums import ( BuiltinNodeTypes, ErrorStrategy, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus, ) -from dify_graph.errors import WorkflowNodeRunFailedError -from dify_graph.graph_events import NodeRunFailedEvent, NodeRunSucceededEvent -from dify_graph.node_events import NodeRunResult -from dify_graph.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, HttpRequestNode, HttpRequestNodeConfig -from dify_graph.variables.input_entities import VariableEntityType +from graphon.errors import WorkflowNodeRunFailedError +from graphon.graph_events import NodeRunFailedEvent, NodeRunSucceededEvent +from graphon.model_runtime.entities.model_entities import ModelType +from graphon.node_events import NodeRunResult +from graphon.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, HttpRequestNode, HttpRequestNodeConfig +from graphon.variables.input_entities import VariableEntityType + from libs.datetime_utils import naive_utc_now from models.human_input import RecipientType from models.model import App, AppMode @@ -1545,7 +1546,10 @@ class TestWorkflowServiceCredentialValidation: def test_validate_llm_model_config_should_raise_value_error_on_failure(self, service: WorkflowService) -> None: """If ModelManager raises any exception it must be wrapped into ValueError.""" # Arrange - with patch("core.model_manager.ModelManager.get_model_instance", side_effect=RuntimeError("no key")): + assembly = MagicMock() + assembly.model_manager.get_model_instance.side_effect = RuntimeError("no key") + 
+ with patch("services.workflow_service.create_plugin_model_assembly", return_value=assembly): # Act + Assert with pytest.raises(ValueError, match="Failed to validate LLM model configuration"): service._validate_llm_model_config("tenant-1", "openai", "gpt-4") @@ -1558,30 +1562,30 @@ class TestWorkflowServiceCredentialValidation: mock_configs = MagicMock() mock_configs.get_models.return_value = [mock_model] + assembly = MagicMock() + assembly.provider_manager.get_configurations.return_value = mock_configs - with ( - patch("core.model_manager.ModelManager.get_model_instance"), - patch("core.provider_manager.ProviderManager") as mock_pm_cls, - ): - mock_pm_cls.return_value.get_configurations.return_value = mock_configs - + with patch("services.workflow_service.create_plugin_model_assembly", return_value=assembly): # Act service._validate_llm_model_config("tenant-1", "openai", "gpt-4") # Assert mock_model.raise_for_status.assert_called_once() + assembly.model_manager.get_model_instance.assert_called_once_with( + tenant_id="tenant-1", + provider="openai", + model_type=ModelType.LLM, + model="gpt-4", + ) def test_validate_llm_model_config_model_not_found(self, service: WorkflowService) -> None: """Test ValueError when model is not found in provider configurations.""" mock_configs = MagicMock() mock_configs.get_models.return_value = [] # No models + assembly = MagicMock() + assembly.provider_manager.get_configurations.return_value = mock_configs - with ( - patch("core.model_manager.ModelManager.get_model_instance"), - patch("core.provider_manager.ProviderManager") as mock_pm_cls, - ): - mock_pm_cls.return_value.get_configurations.return_value = mock_configs - + with patch("services.workflow_service.create_plugin_model_assembly", return_value=assembly): # Act + Assert with pytest.raises(ValueError, match="Model gpt-4 not found for provider openai"): service._validate_llm_model_config("tenant-1", "openai", "gpt-4") @@ -2053,22 +2057,37 @@ class TestSetupVariablePool: 
workflow = self._make_workflow() # Act - with patch("services.workflow_service.VariablePool") as MockPool: + with ( + patch("services.workflow_service.VariablePool") as MockPool, + patch("services.workflow_service.build_system_variables") as mock_build_system_variables, + patch("services.workflow_service.build_bootstrap_variables") as mock_build_bootstrap_variables, + patch("services.workflow_service.add_variables_to_pool") as mock_add_variables_to_pool, + patch("services.workflow_service.add_node_inputs_to_pool") as mock_add_node_inputs_to_pool, + ): _setup_variable_pool( query="hello", files=[], user_id="u-1", user_inputs={"k": "v"}, workflow=workflow, + node_id="start-node", node_type=BuiltinNodeTypes.START, conversation_id="conv-1", conversation_variables=[], ) - # Assert — VariablePool should be called with a SystemVariable (non-default) - MockPool.assert_called_once() - call_kwargs = MockPool.call_args.kwargs - assert call_kwargs["user_inputs"] == {"k": "v"} + # Assert — start nodes should build bootstrap variables and attach node inputs. 
+ MockPool.assert_called_once_with() + mock_build_system_variables.assert_called_once() + mock_add_variables_to_pool.assert_called_once_with( + MockPool.return_value, + mock_build_bootstrap_variables.return_value, + ) + mock_add_node_inputs_to_pool.assert_called_once_with( + MockPool.return_value, + node_id="start-node", + inputs={"k": "v"}, + ) def test_setup_variable_pool_should_use_default_system_variables_for_non_start_node( self, @@ -2079,7 +2098,10 @@ class TestSetupVariablePool: # Act with ( patch("services.workflow_service.VariablePool") as MockPool, - patch("services.workflow_service.SystemVariable.default") as mock_default, + patch("services.workflow_service.default_system_variables") as mock_default_system_variables, + patch("services.workflow_service.build_bootstrap_variables") as mock_build_bootstrap_variables, + patch("services.workflow_service.add_variables_to_pool") as mock_add_variables_to_pool, + patch("services.workflow_service.add_node_inputs_to_pool") as mock_add_node_inputs_to_pool, ): _setup_variable_pool( query="", @@ -2087,14 +2109,20 @@ class TestSetupVariablePool: user_id="u-1", user_inputs={}, workflow=workflow, + node_id="llm-node", node_type=BuiltinNodeTypes.LLM, # not a start/trigger node conversation_id="conv-1", conversation_variables=[], ) - # Assert — SystemVariable.default() should be used for non-start nodes - mock_default.assert_called_once() - MockPool.assert_called_once() + # Assert — default system variables should be used and node inputs should not be added. 
+ mock_default_system_variables.assert_called_once() + MockPool.assert_called_once_with() + mock_add_variables_to_pool.assert_called_once_with( + MockPool.return_value, + mock_build_bootstrap_variables.return_value, + ) + mock_add_node_inputs_to_pool.assert_not_called() def test_setup_variable_pool_should_set_chatflow_specifics_for_non_workflow_type( self, @@ -2106,20 +2134,31 @@ class TestSetupVariablePool: workflow = self._make_workflow(workflow_type=WorkflowType.CHAT.value) # Act - with patch("services.workflow_service.VariablePool") as MockPool: + with ( + patch("services.workflow_service.VariablePool") as MockPool, + patch("services.workflow_service.build_system_variables") as mock_build_system_variables, + patch("services.workflow_service.build_bootstrap_variables"), + patch("services.workflow_service.add_variables_to_pool"), + patch("services.workflow_service.add_node_inputs_to_pool"), + ): _setup_variable_pool( query="what is AI?", files=[], user_id="u-1", user_inputs={}, workflow=workflow, + node_id="start-node", node_type=BuiltinNodeTypes.START, conversation_id="conv-abc", conversation_variables=[], ) - # Assert — we just verify VariablePool was called (chatflow path executed) - MockPool.assert_called_once() + # Assert — chatflow system variables should include query, conversation_id and dialogue_count. + MockPool.assert_called_once_with() + system_variable_values = mock_build_system_variables.call_args.args[0] + assert system_variable_values["query"] == "what is AI?" 
+ assert system_variable_values["conversation_id"] == "conv-abc" + assert system_variable_values["dialogue_count"] == 1 class TestRebuildSingleFile: @@ -2142,7 +2181,7 @@ class TestRebuildSingleFile: # Assert assert result is mock_file - mock_build.assert_called_once_with(mapping=value, tenant_id=tenant_id) + mock_build.assert_called_once_with(mapping=value, tenant_id=tenant_id, access_controller=ANY) def test_rebuild_single_file_should_raise_when_file_value_not_dict( self, @@ -2165,7 +2204,7 @@ class TestRebuildSingleFile: # Assert assert result is mock_files - mock_build.assert_called_once_with(mappings=value, tenant_id=tenant_id) + mock_build.assert_called_once_with(mappings=value, tenant_id=tenant_id, access_controller=ANY) def test_rebuild_single_file_should_raise_when_file_list_value_not_list( self, @@ -2279,13 +2318,12 @@ class TestWorkflowServiceResolveDeliveryMethod: # Arrange method_a = self._make_method("method-1") method_b = self._make_method("method-2") - node_data = MagicMock() - node_data.delivery_methods = [method_a, method_b] # Act - result = WorkflowService._resolve_human_input_delivery_method( - node_data=node_data, delivery_method_id="method-2" - ) + with patch("services.workflow_service.parse_human_input_delivery_methods", return_value=[method_a, method_b]): + result = WorkflowService._resolve_human_input_delivery_method( + node_data=MagicMock(), delivery_method_id="method-2" + ) # Assert assert result is method_b @@ -2293,26 +2331,22 @@ class TestWorkflowServiceResolveDeliveryMethod: def test_resolve_delivery_method_should_return_none_when_no_match(self) -> None: # Arrange method_a = self._make_method("method-1") - node_data = MagicMock() - node_data.delivery_methods = [method_a] # Act - result = WorkflowService._resolve_human_input_delivery_method( - node_data=node_data, delivery_method_id="does-not-exist" - ) + with patch("services.workflow_service.parse_human_input_delivery_methods", return_value=[method_a]): + result = 
WorkflowService._resolve_human_input_delivery_method( + node_data=MagicMock(), delivery_method_id="does-not-exist" + ) # Assert assert result is None def test_resolve_delivery_method_should_return_none_for_empty_methods(self) -> None: - # Arrange - node_data = MagicMock() - node_data.delivery_methods = [] - # Act - result = WorkflowService._resolve_human_input_delivery_method( - node_data=node_data, delivery_method_id="method-1" - ) + with patch("services.workflow_service.parse_human_input_delivery_methods", return_value=[]): + result = WorkflowService._resolve_human_input_delivery_method( + node_data=MagicMock(), delivery_method_id="method-1" + ) # Assert assert result is None @@ -2435,6 +2469,9 @@ class TestWorkflowServiceDraftExecution: patch("services.workflow_service.Session"), patch("services.workflow_service.WorkflowDraftVariableService"), patch("services.workflow_service.VariablePool") as mock_pool_cls, + patch("services.workflow_service.default_system_variables") as mock_default_system_variables, + patch("services.workflow_service.build_bootstrap_variables") as mock_build_bootstrap_variables, + patch("services.workflow_service.add_variables_to_pool") as mock_add_variables_to_pool, patch("services.workflow_service.DraftVarLoader"), patch("services.workflow_service.WorkflowEntry.single_step_run") as mock_run, patch("services.workflow_service.DifyCoreRepositoryFactory"), @@ -2475,10 +2512,16 @@ class TestWorkflowServiceDraftExecution: ) # Assert - # For non-start nodes, VariablePool should be initialized with environment_variables - mock_pool_cls.assert_called_once() - args, kwargs = mock_pool_cls.call_args - assert "environment_variables" in kwargs + # For non-start nodes, bootstrap variables should be loaded into an empty pool. 
+ mock_pool_cls.assert_called_once_with() + mock_default_system_variables.assert_called_once() + mock_build_bootstrap_variables.assert_called_once_with( + system_variables=mock_default_system_variables.return_value, + environment_variables=draft_workflow.environment_variables, + ) + mock_add_variables_to_pool.assert_called_once_with( + mock_pool_cls.return_value, mock_build_bootstrap_variables.return_value + ) # =========================================================================== @@ -2588,7 +2631,7 @@ class TestWorkflowServiceHumanInputOperations: patch("models.workflow.Workflow.get_node_type_from_node_config", return_value=BuiltinNodeTypes.HUMAN_INPUT), patch("services.workflow_service.HumanInputNodeData.model_validate"), patch.object(service, "_resolve_human_input_delivery_method") as mock_resolve, - patch("services.workflow_service.apply_debug_email_recipient"), + patch("services.workflow_service.apply_dify_debug_email_recipient"), patch.object(service, "_build_human_input_variable_pool"), patch.object(service, "_build_human_input_node"), patch.object(service, "_create_human_input_delivery_test_form", return_value=("form-1", [])), @@ -2708,7 +2751,7 @@ class TestWorkflowServiceFreeNodeExecution: def test_validate_human_input_node_data_error(self, service: WorkflowService) -> None: with patch( - "dify_graph.nodes.human_input.entities.HumanInputNodeData.model_validate", side_effect=Exception("error") + "graphon.nodes.human_input.entities.HumanInputNodeData.model_validate", side_effect=Exception("error") ): with pytest.raises(ValueError, match="Invalid HumanInput node data"): service._validate_human_input_node_data({}) @@ -2730,13 +2773,15 @@ class TestWorkflowServiceFreeNodeExecution: variable_pool = MagicMock() with ( - patch("services.workflow_service.GraphInitParams"), + patch("services.workflow_service.GraphInitParams") as mock_graph_init_params, patch("services.workflow_service.GraphRuntimeState"), + 
patch("services.workflow_service.build_dify_run_context"), + patch("services.workflow_service.DifyHumanInputNodeRuntime") as mock_runtime_cls, patch("services.workflow_service.HumanInputNode") as mock_node_cls, - patch("services.workflow_service.HumanInputFormRepositoryImpl"), ): node = service._build_human_input_node( workflow=workflow, account=account, node_config=node_config, variable_pool=variable_pool ) assert node == mock_node_cls.return_value mock_node_cls.assert_called_once() + mock_runtime_cls.assert_called_once_with(mock_graph_init_params.return_value.run_context) diff --git a/api/tests/unit_tests/services/test_workspace_service.py b/api/tests/unit_tests/services/test_workspace_service.py deleted file mode 100644 index 9bfd7eb2c5..0000000000 --- a/api/tests/unit_tests/services/test_workspace_service.py +++ /dev/null @@ -1,576 +0,0 @@ -from __future__ import annotations - -from types import SimpleNamespace -from typing import Any, cast -from unittest.mock import MagicMock - -import pytest -from pytest_mock import MockerFixture - -from models.account import Tenant - -# --------------------------------------------------------------------------- -# Constants used throughout the tests -# --------------------------------------------------------------------------- - -TENANT_ID = "tenant-abc" -ACCOUNT_ID = "account-xyz" -FILES_BASE_URL = "https://files.example.com" - -DB_PATH = "services.workspace_service.db" -FEATURE_SERVICE_PATH = "services.workspace_service.FeatureService.get_features" -TENANT_SERVICE_PATH = "services.workspace_service.TenantService.has_roles" -DIFY_CONFIG_PATH = "services.workspace_service.dify_config" -CURRENT_USER_PATH = "services.workspace_service.current_user" -CREDIT_POOL_SERVICE_PATH = "services.credit_pool_service.CreditPoolService.get_pool" - - -# --------------------------------------------------------------------------- -# Helpers / factories -# --------------------------------------------------------------------------- - - -def 
_make_tenant( - tenant_id: str = TENANT_ID, - name: str = "My Workspace", - plan: str = "sandbox", - status: str = "active", - custom_config: dict | None = None, -) -> Tenant: - """Create a minimal Tenant-like namespace.""" - return cast( - Tenant, - SimpleNamespace( - id=tenant_id, - name=name, - plan=plan, - status=status, - created_at="2024-01-01T00:00:00Z", - custom_config_dict=custom_config or {}, - ), - ) - - -def _make_feature( - can_replace_logo: bool = False, - next_credit_reset_date: str | None = None, - billing_plan: str = "sandbox", -) -> MagicMock: - """Create a feature namespace matching what FeatureService.get_features returns.""" - feature = MagicMock() - feature.can_replace_logo = can_replace_logo - feature.next_credit_reset_date = next_credit_reset_date - feature.billing.subscription.plan = billing_plan - return feature - - -def _make_pool(quota_limit: int, quota_used: int) -> MagicMock: - pool = MagicMock() - pool.quota_limit = quota_limit - pool.quota_used = quota_used - return pool - - -def _make_tenant_account_join(role: str = "normal") -> SimpleNamespace: - return SimpleNamespace(role=role) - - -def _tenant_info(result: object) -> dict[str, Any] | None: - return cast(dict[str, Any] | None, result) - - -# --------------------------------------------------------------------------- -# Shared fixtures -# --------------------------------------------------------------------------- - - -@pytest.fixture -def mock_current_user() -> SimpleNamespace: - """Return a lightweight current_user stand-in.""" - return SimpleNamespace(id=ACCOUNT_ID) - - -@pytest.fixture -def basic_mocks(mocker: MockerFixture, mock_current_user: SimpleNamespace) -> dict: - """ - Patch the common external boundaries used by WorkspaceService.get_tenant_info. - - Returns a dict of named mocks so individual tests can customise them. 
- """ - mocker.patch(CURRENT_USER_PATH, mock_current_user) - - mock_db_session = mocker.patch(f"{DB_PATH}.session") - mock_query_chain = MagicMock() - mock_db_session.query.return_value = mock_query_chain - mock_query_chain.where.return_value = mock_query_chain - mock_query_chain.first.return_value = _make_tenant_account_join(role="owner") - - mock_feature = mocker.patch(FEATURE_SERVICE_PATH, return_value=_make_feature()) - mock_has_roles = mocker.patch(TENANT_SERVICE_PATH, return_value=False) - mock_config = mocker.patch(DIFY_CONFIG_PATH) - mock_config.EDITION = "SELF_HOSTED" - mock_config.FILES_URL = FILES_BASE_URL - - return { - "db_session": mock_db_session, - "query_chain": mock_query_chain, - "get_features": mock_feature, - "has_roles": mock_has_roles, - "config": mock_config, - } - - -# --------------------------------------------------------------------------- -# 1. None Tenant Handling -# --------------------------------------------------------------------------- - - -def test_get_tenant_info_should_return_none_when_tenant_is_none() -> None: - """get_tenant_info should short-circuit and return None for a falsy tenant.""" - from services.workspace_service import WorkspaceService - - # Arrange - tenant = None - - # Act - result = WorkspaceService.get_tenant_info(cast(Tenant, tenant)) - - # Assert - assert result is None - - -def test_get_tenant_info_should_return_none_when_tenant_is_falsy() -> None: - """get_tenant_info treats any falsy value as absent (e.g. empty string, 0).""" - from services.workspace_service import WorkspaceService - - # Arrange / Act / Assert - assert WorkspaceService.get_tenant_info("") is None # type: ignore[arg-type] - - -# --------------------------------------------------------------------------- -# 2. 
Basic Tenant Info — happy path -# --------------------------------------------------------------------------- - - -def test_get_tenant_info_should_return_base_fields( - mocker: MockerFixture, - basic_mocks: dict, -) -> None: - """get_tenant_info should always return the six base scalar fields.""" - from services.workspace_service import WorkspaceService - - # Arrange - tenant = _make_tenant() - - # Act - result = _tenant_info(WorkspaceService.get_tenant_info(tenant)) - - # Assert - assert result is not None - assert result["id"] == TENANT_ID - assert result["name"] == "My Workspace" - assert result["plan"] == "sandbox" - assert result["status"] == "active" - assert result["created_at"] == "2024-01-01T00:00:00Z" - assert result["trial_end_reason"] is None - - -def test_get_tenant_info_should_populate_role_from_tenant_account_join( - mocker: MockerFixture, - basic_mocks: dict, -) -> None: - """The 'role' field should be taken from TenantAccountJoin, not the default.""" - from services.workspace_service import WorkspaceService - - # Arrange - basic_mocks["query_chain"].first.return_value = _make_tenant_account_join(role="admin") - tenant = _make_tenant() - - # Act - result = _tenant_info(WorkspaceService.get_tenant_info(tenant)) - - # Assert - assert result is not None - assert result["role"] == "admin" - - -def test_get_tenant_info_should_raise_assertion_when_tenant_account_join_missing( - mocker: MockerFixture, - basic_mocks: dict, -) -> None: - """ - The service asserts that TenantAccountJoin exists. - Missing join should raise AssertionError. - """ - from services.workspace_service import WorkspaceService - - # Arrange - basic_mocks["query_chain"].first.return_value = None - tenant = _make_tenant() - - # Act + Assert - with pytest.raises(AssertionError, match="TenantAccountJoin not found"): - WorkspaceService.get_tenant_info(tenant) - - -# --------------------------------------------------------------------------- -# 3. 
Logo Customisation -# --------------------------------------------------------------------------- - - -def test_get_tenant_info_should_include_custom_config_when_logo_allowed_and_admin( - mocker: MockerFixture, - basic_mocks: dict, -) -> None: - """custom_config block should appear for OWNER/ADMIN when can_replace_logo is True.""" - from services.workspace_service import WorkspaceService - - # Arrange - basic_mocks["get_features"].return_value = _make_feature(can_replace_logo=True) - basic_mocks["has_roles"].return_value = True - tenant = _make_tenant( - custom_config={ - "replace_webapp_logo": True, - "remove_webapp_brand": True, - } - ) - - # Act - result = _tenant_info(WorkspaceService.get_tenant_info(tenant)) - - # Assert - assert result is not None - assert "custom_config" in result - assert result["custom_config"]["remove_webapp_brand"] is True - expected_logo_url = f"{FILES_BASE_URL}/files/workspaces/{TENANT_ID}/webapp-logo" - assert result["custom_config"]["replace_webapp_logo"] == expected_logo_url - - -def test_get_tenant_info_should_set_replace_webapp_logo_to_none_when_flag_absent( - mocker: MockerFixture, - basic_mocks: dict, -) -> None: - """replace_webapp_logo should be None when custom_config_dict does not have the key.""" - from services.workspace_service import WorkspaceService - - # Arrange - basic_mocks["get_features"].return_value = _make_feature(can_replace_logo=True) - basic_mocks["has_roles"].return_value = True - tenant = _make_tenant(custom_config={}) # no replace_webapp_logo key - - # Act - result = _tenant_info(WorkspaceService.get_tenant_info(tenant)) - - # Assert - assert result is not None - assert result["custom_config"]["replace_webapp_logo"] is None - - -def test_get_tenant_info_should_not_include_custom_config_when_logo_not_allowed( - mocker: MockerFixture, - basic_mocks: dict, -) -> None: - """custom_config should be absent when can_replace_logo is False.""" - from services.workspace_service import WorkspaceService - - # Arrange - 
basic_mocks["get_features"].return_value = _make_feature(can_replace_logo=False) - basic_mocks["has_roles"].return_value = True - tenant = _make_tenant() - - # Act - result = _tenant_info(WorkspaceService.get_tenant_info(tenant)) - - # Assert - assert result is not None - assert "custom_config" not in result - - -def test_get_tenant_info_should_not_include_custom_config_when_user_not_admin( - mocker: MockerFixture, - basic_mocks: dict, -) -> None: - """custom_config block is gated on OWNER or ADMIN role.""" - from services.workspace_service import WorkspaceService - - # Arrange - basic_mocks["get_features"].return_value = _make_feature(can_replace_logo=True) - basic_mocks["has_roles"].return_value = False # regular member - tenant = _make_tenant() - - # Act - result = _tenant_info(WorkspaceService.get_tenant_info(tenant)) - - # Assert - assert result is not None - assert "custom_config" not in result - - -def test_get_tenant_info_should_use_files_url_for_logo_url( - mocker: MockerFixture, - basic_mocks: dict, -) -> None: - """The logo URL should use dify_config.FILES_URL as the base.""" - from services.workspace_service import WorkspaceService - - # Arrange - custom_base = "https://cdn.mycompany.io" - basic_mocks["config"].FILES_URL = custom_base - basic_mocks["get_features"].return_value = _make_feature(can_replace_logo=True) - basic_mocks["has_roles"].return_value = True - tenant = _make_tenant(custom_config={"replace_webapp_logo": True}) - - # Act - result = _tenant_info(WorkspaceService.get_tenant_info(tenant)) - - # Assert - assert result is not None - assert result["custom_config"]["replace_webapp_logo"].startswith(custom_base) - - -# --------------------------------------------------------------------------- -# 4. 
Cloud-Edition Credit Features -# --------------------------------------------------------------------------- - -CLOUD_BILLING_PLAN_NON_SANDBOX = "professional" # any plan that is not SANDBOX - - -@pytest.fixture -def cloud_mocks(mocker: MockerFixture, mock_current_user: SimpleNamespace) -> dict: - """Patches for CLOUD edition tests, billing plan = professional by default.""" - mocker.patch(CURRENT_USER_PATH, mock_current_user) - - mock_db_session = mocker.patch(f"{DB_PATH}.session") - mock_query_chain = MagicMock() - mock_db_session.query.return_value = mock_query_chain - mock_query_chain.where.return_value = mock_query_chain - mock_query_chain.first.return_value = _make_tenant_account_join(role="owner") - - mock_feature = mocker.patch( - FEATURE_SERVICE_PATH, - return_value=_make_feature( - can_replace_logo=False, - next_credit_reset_date="2025-02-01", - billing_plan=CLOUD_BILLING_PLAN_NON_SANDBOX, - ), - ) - mocker.patch(TENANT_SERVICE_PATH, return_value=False) - mock_config = mocker.patch(DIFY_CONFIG_PATH) - mock_config.EDITION = "CLOUD" - mock_config.FILES_URL = FILES_BASE_URL - - return { - "db_session": mock_db_session, - "query_chain": mock_query_chain, - "get_features": mock_feature, - "config": mock_config, - } - - -def test_get_tenant_info_should_add_next_credit_reset_date_in_cloud_edition( - mocker: MockerFixture, - cloud_mocks: dict, -) -> None: - """next_credit_reset_date should be present in CLOUD edition.""" - from services.workspace_service import WorkspaceService - - # Arrange - mocker.patch( - CREDIT_POOL_SERVICE_PATH, - side_effect=[None, None], # both paid and trial pools absent - ) - tenant = _make_tenant() - - # Act - result = _tenant_info(WorkspaceService.get_tenant_info(tenant)) - - # Assert - assert result is not None - assert result["next_credit_reset_date"] == "2025-02-01" - - -def test_get_tenant_info_should_use_paid_pool_when_plan_is_not_sandbox_and_pool_not_full( - mocker: MockerFixture, - cloud_mocks: dict, -) -> None: - 
"""trial_credits/trial_credits_used come from the paid pool when conditions are met.""" - from services.workspace_service import WorkspaceService - - # Arrange - paid_pool = _make_pool(quota_limit=1000, quota_used=200) - mocker.patch(CREDIT_POOL_SERVICE_PATH, return_value=paid_pool) - tenant = _make_tenant() - - # Act - result = _tenant_info(WorkspaceService.get_tenant_info(tenant)) - - # Assert - assert result is not None - assert result["trial_credits"] == 1000 - assert result["trial_credits_used"] == 200 - - -def test_get_tenant_info_should_use_paid_pool_when_quota_limit_is_infinite( - mocker: MockerFixture, - cloud_mocks: dict, -) -> None: - """quota_limit == -1 means unlimited; service should still use the paid pool.""" - from services.workspace_service import WorkspaceService - - # Arrange - paid_pool = _make_pool(quota_limit=-1, quota_used=999) - mocker.patch(CREDIT_POOL_SERVICE_PATH, side_effect=[paid_pool, None]) - tenant = _make_tenant() - - # Act - result = _tenant_info(WorkspaceService.get_tenant_info(tenant)) - - # Assert - assert result is not None - assert result["trial_credits"] == -1 - assert result["trial_credits_used"] == 999 - - -def test_get_tenant_info_should_fall_back_to_trial_pool_when_paid_pool_is_full( - mocker: MockerFixture, - cloud_mocks: dict, -) -> None: - """When paid pool is exhausted (used >= limit), switch to trial pool.""" - from services.workspace_service import WorkspaceService - - # Arrange - paid_pool = _make_pool(quota_limit=500, quota_used=500) # exactly full - trial_pool = _make_pool(quota_limit=100, quota_used=10) - mocker.patch(CREDIT_POOL_SERVICE_PATH, side_effect=[paid_pool, trial_pool]) - tenant = _make_tenant() - - # Act - result = _tenant_info(WorkspaceService.get_tenant_info(tenant)) - - # Assert - assert result is not None - assert result["trial_credits"] == 100 - assert result["trial_credits_used"] == 10 - - -def test_get_tenant_info_should_fall_back_to_trial_pool_when_paid_pool_is_none( - mocker: MockerFixture, 
- cloud_mocks: dict, -) -> None: - """When paid_pool is None, fall back to trial pool.""" - from services.workspace_service import WorkspaceService - - # Arrange - trial_pool = _make_pool(quota_limit=50, quota_used=5) - mocker.patch(CREDIT_POOL_SERVICE_PATH, side_effect=[None, trial_pool]) - tenant = _make_tenant() - - # Act - result = _tenant_info(WorkspaceService.get_tenant_info(tenant)) - - # Assert - assert result is not None - assert result["trial_credits"] == 50 - assert result["trial_credits_used"] == 5 - - -def test_get_tenant_info_should_fall_back_to_trial_pool_for_sandbox_plan( - mocker: MockerFixture, - cloud_mocks: dict, -) -> None: - """ - When the subscription plan IS SANDBOX, the paid pool branch is skipped - entirely and we fall back to the trial pool. - """ - from enums.cloud_plan import CloudPlan - from services.workspace_service import WorkspaceService - - # Arrange — override billing plan to SANDBOX - cloud_mocks["get_features"].return_value = _make_feature( - next_credit_reset_date="2025-02-01", - billing_plan=CloudPlan.SANDBOX, - ) - paid_pool = _make_pool(quota_limit=1000, quota_used=0) - trial_pool = _make_pool(quota_limit=200, quota_used=20) - mocker.patch(CREDIT_POOL_SERVICE_PATH, side_effect=[paid_pool, trial_pool]) - tenant = _make_tenant() - - # Act - result = _tenant_info(WorkspaceService.get_tenant_info(tenant)) - - # Assert - assert result is not None - assert result["trial_credits"] == 200 - assert result["trial_credits_used"] == 20 - - -def test_get_tenant_info_should_omit_trial_credits_when_both_pools_are_none( - mocker: MockerFixture, - cloud_mocks: dict, -) -> None: - """When both paid and trial pools are absent, trial_credits should not be set.""" - from services.workspace_service import WorkspaceService - - # Arrange - mocker.patch(CREDIT_POOL_SERVICE_PATH, side_effect=[None, None]) - tenant = _make_tenant() - - # Act - result = _tenant_info(WorkspaceService.get_tenant_info(tenant)) - - # Assert - assert result is not None - 
assert "trial_credits" not in result - assert "trial_credits_used" not in result - - -# --------------------------------------------------------------------------- -# 5. Self-hosted / Non-Cloud Edition -# --------------------------------------------------------------------------- - - -def test_get_tenant_info_should_not_include_cloud_fields_in_self_hosted( - mocker: MockerFixture, - basic_mocks: dict, -) -> None: - """next_credit_reset_date and trial_credits should NOT appear in SELF_HOSTED mode.""" - from services.workspace_service import WorkspaceService - - # Arrange (basic_mocks already sets EDITION = "SELF_HOSTED") - tenant = _make_tenant() - - # Act - result = _tenant_info(WorkspaceService.get_tenant_info(tenant)) - - # Assert - assert result is not None - assert "next_credit_reset_date" not in result - assert "trial_credits" not in result - assert "trial_credits_used" not in result - - -# --------------------------------------------------------------------------- -# 6. DB query integrity -# --------------------------------------------------------------------------- - - -def test_get_tenant_info_should_query_tenant_account_join_with_correct_ids( - mocker: MockerFixture, - basic_mocks: dict, -) -> None: - """ - The DB query for TenantAccountJoin must be scoped to the correct - tenant_id and current_user.id. 
- """ - from services.workspace_service import WorkspaceService - - # Arrange - tenant = _make_tenant(tenant_id="my-special-tenant") - mock_current_user = mocker.patch(CURRENT_USER_PATH) - mock_current_user.id = "special-user-id" - - # Act - WorkspaceService.get_tenant_info(tenant) - - # Assert — db.session.query was invoked (at least once) - basic_mocks["db_session"].query.assert_called() diff --git a/api/tests/unit_tests/services/vector_service.py b/api/tests/unit_tests/services/vector_service.py index 33a5607ef4..ee9ba1c6d6 100644 --- a/api/tests/unit_tests/services/vector_service.py +++ b/api/tests/unit_tests/services/vector_service.py @@ -522,7 +522,7 @@ class TestVectorService: assert call_args[1]["keywords_list"] == keywords_list @patch("services.vector_service.VectorService.generate_child_chunks") - @patch("services.vector_service.ModelManager") + @patch("services.vector_service.ModelManager.for_tenant") @patch("services.vector_service.db") def test_create_segments_vector_parent_child_indexing( self, mock_db, mock_model_manager, mock_generate_child_chunks @@ -1755,7 +1755,7 @@ class TestVector: # ======================================================================== @patch("core.rag.datasource.vdb.vector_factory.CacheEmbedding") - @patch("core.rag.datasource.vdb.vector_factory.ModelManager") + @patch("core.rag.datasource.vdb.vector_factory.ModelManager.for_tenant") @patch("core.rag.datasource.vdb.vector_factory.Vector._init_vector") def test_vector_get_embeddings(self, mock_init_vector, mock_model_manager, mock_cache_embedding): """ diff --git a/api/tests/unit_tests/services/workflow/test_draft_var_loader_simple.py b/api/tests/unit_tests/services/workflow/test_draft_var_loader_simple.py index f3391d6380..8525672da8 100644 --- a/api/tests/unit_tests/services/workflow/test_draft_var_loader_simple.py +++ b/api/tests/unit_tests/services/workflow/test_draft_var_loader_simple.py @@ -4,10 +4,12 @@ import json from unittest.mock import Mock, patch import pytest 
+from graphon.file import File, FileTransferMethod, FileType +from graphon.variables.segments import ObjectSegment, StringSegment +from graphon.variables.types import SegmentType from sqlalchemy import Engine -from dify_graph.variables.segments import ObjectSegment, StringSegment -from dify_graph.variables.types import SegmentType +from core.workflow.file_reference import build_file_reference from models.model import UploadFile from models.workflow import WorkflowDraftVariable, WorkflowDraftVariableFile from services.workflow_draft_variable_service import DraftVarLoader @@ -54,25 +56,18 @@ class TestDraftVarLoaderSimple: with patch("services.workflow_draft_variable_service.storage") as mock_storage: mock_storage.load.return_value = test_content.encode() - with patch("factories.variable_factory.segment_to_variable") as mock_segment_to_variable: - mock_variable = Mock() - mock_variable.id = "draft-var-id" - mock_variable.name = "test_variable" - mock_variable.value = StringSegment(value=test_content) - mock_segment_to_variable.return_value = mock_variable + # Execute the method + selector_tuple, variable = draft_var_loader._load_offloaded_variable(draft_var) - # Execute the method - selector_tuple, variable = draft_var_loader._load_offloaded_variable(draft_var) + # Verify results + assert selector_tuple == ("test-node-id", "test_variable") + assert variable.id == "draft-var-id" + assert variable.name == "test_variable" + assert variable.description == "test description" + assert variable.value == test_content - # Verify results - assert selector_tuple == ("test-node-id", "test_variable") - assert variable.id == "draft-var-id" - assert variable.name == "test_variable" - assert variable.description == "test description" - assert variable.value == test_content - - # Verify storage was called correctly - mock_storage.load.assert_called_once_with("storage/key/test.txt") + # Verify storage was called correctly + 
mock_storage.load.assert_called_once_with("storage/key/test.txt") def test_load_offloaded_variable_object_type_unit(self, draft_var_loader): """Test _load_offloaded_variable with object type - isolated unit test.""" @@ -97,31 +92,22 @@ class TestDraftVarLoaderSimple: with patch("services.workflow_draft_variable_service.storage") as mock_storage: mock_storage.load.return_value = test_json_content.encode() + mock_segment = ObjectSegment(value=test_object) + draft_var.build_segment_from_serialized_value.return_value = mock_segment - with patch.object(WorkflowDraftVariable, "build_segment_with_type") as mock_build_segment: - mock_segment = ObjectSegment(value=test_object) - mock_build_segment.return_value = mock_segment + # Execute the method + selector_tuple, variable = draft_var_loader._load_offloaded_variable(draft_var) - with patch("factories.variable_factory.segment_to_variable") as mock_segment_to_variable: - mock_variable = Mock() - mock_variable.id = "draft-var-id" - mock_variable.name = "test_object" - mock_variable.value = mock_segment - mock_segment_to_variable.return_value = mock_variable + # Verify results + assert selector_tuple == ("test-node-id", "test_object") + assert variable.id == "draft-var-id" + assert variable.name == "test_object" + assert variable.description == "test description" + assert variable.value == test_object - # Execute the method - selector_tuple, variable = draft_var_loader._load_offloaded_variable(draft_var) - - # Verify results - assert selector_tuple == ("test-node-id", "test_object") - assert variable.id == "draft-var-id" - assert variable.name == "test_object" - assert variable.description == "test description" - assert variable.value == test_object - - # Verify method calls - mock_storage.load.assert_called_once_with("storage/key/test.json") - mock_build_segment.assert_called_once_with(SegmentType.OBJECT, test_object) + # Verify method calls + mock_storage.load.assert_called_once_with("storage/key/test.json") + 
draft_var.build_segment_from_serialized_value.assert_called_once_with(SegmentType.OBJECT, test_object) def test_load_offloaded_variable_missing_variable_file_unit(self, draft_var_loader): """Test that assertion error is raised when variable_file is None.""" @@ -176,32 +162,23 @@ class TestDraftVarLoaderSimple: with patch("services.workflow_draft_variable_service.storage") as mock_storage: mock_storage.load.return_value = test_json_content.encode() + from graphon.variables.segments import FloatSegment - with patch.object(WorkflowDraftVariable, "build_segment_with_type") as mock_build_segment: - from dify_graph.variables.segments import FloatSegment + mock_segment = FloatSegment(value=test_number) + draft_var.build_segment_from_serialized_value.return_value = mock_segment - mock_segment = FloatSegment(value=test_number) - mock_build_segment.return_value = mock_segment + # Execute the method + selector_tuple, variable = draft_var_loader._load_offloaded_variable(draft_var) - with patch("factories.variable_factory.segment_to_variable") as mock_segment_to_variable: - mock_variable = Mock() - mock_variable.id = "draft-var-id" - mock_variable.name = "test_number" - mock_variable.value = mock_segment - mock_segment_to_variable.return_value = mock_variable + # Verify results + assert selector_tuple == ("test-node-id", "test_number") + assert variable.id == "draft-var-id" + assert variable.name == "test_number" + assert variable.description == "test number description" - # Execute the method - selector_tuple, variable = draft_var_loader._load_offloaded_variable(draft_var) - - # Verify results - assert selector_tuple == ("test-node-id", "test_number") - assert variable.id == "draft-var-id" - assert variable.name == "test_number" - assert variable.description == "test number description" - - # Verify method calls - mock_storage.load.assert_called_once_with("storage/key/test_number.json") - mock_build_segment.assert_called_once_with(SegmentType.NUMBER, test_number) + # Verify 
method calls + mock_storage.load.assert_called_once_with("storage/key/test_number.json") + draft_var.build_segment_from_serialized_value.assert_called_once_with(SegmentType.NUMBER, test_number) def test_load_offloaded_variable_array_type_unit(self, draft_var_loader): """Test _load_offloaded_variable with array type - isolated unit test.""" @@ -226,32 +203,83 @@ class TestDraftVarLoaderSimple: with patch("services.workflow_draft_variable_service.storage") as mock_storage: mock_storage.load.return_value = test_json_content.encode() + from graphon.variables.segments import ArrayAnySegment - with patch.object(WorkflowDraftVariable, "build_segment_with_type") as mock_build_segment: - from dify_graph.variables.segments import ArrayAnySegment + mock_segment = ArrayAnySegment(value=test_array) + draft_var.build_segment_from_serialized_value.return_value = mock_segment - mock_segment = ArrayAnySegment(value=test_array) - mock_build_segment.return_value = mock_segment + # Execute the method + selector_tuple, variable = draft_var_loader._load_offloaded_variable(draft_var) - with patch("factories.variable_factory.segment_to_variable") as mock_segment_to_variable: - mock_variable = Mock() - mock_variable.id = "draft-var-id" - mock_variable.name = "test_array" - mock_variable.value = mock_segment - mock_segment_to_variable.return_value = mock_variable + # Verify results + assert selector_tuple == ("test-node-id", "test_array") + assert variable.id == "draft-var-id" + assert variable.name == "test_array" + assert variable.description == "test array description" - # Execute the method - selector_tuple, variable = draft_var_loader._load_offloaded_variable(draft_var) + # Verify method calls + mock_storage.load.assert_called_once_with("storage/key/test_array.json") + draft_var.build_segment_from_serialized_value.assert_called_once_with(SegmentType.ARRAY_ANY, test_array) - # Verify results - assert selector_tuple == ("test-node-id", "test_array") - assert variable.id == "draft-var-id" 
- assert variable.name == "test_array" - assert variable.description == "test array description" + def test_load_offloaded_variable_file_type_rebuilds_storage_backed_payload(self, draft_var_loader): + upload_file = Mock(spec=UploadFile) + upload_file.key = "storage/key/test_file.json" - # Verify method calls - mock_storage.load.assert_called_once_with("storage/key/test_array.json") - mock_build_segment.assert_called_once_with(SegmentType.ARRAY_ANY, test_array) + variable_file = Mock(spec=WorkflowDraftVariableFile) + variable_file.value_type = SegmentType.FILE + variable_file.upload_file = upload_file + + draft_var = WorkflowDraftVariable() + draft_var.id = "draft-var-id" + draft_var.app_id = "app-1" + draft_var.node_id = "test-node-id" + draft_var.name = "test_file" + draft_var.description = "test file description" + draft_var._set_selector(["test-node-id", "test_file"]) + draft_var.variable_file = variable_file + + persisted_file = File( + id="file-1", + type=FileType.DOCUMENT, + transfer_method=FileTransferMethod.LOCAL_FILE, + reference=build_file_reference(record_id="upload-1", storage_key="legacy-storage-key"), + filename="test.txt", + extension=".txt", + mime_type="text/plain", + size=12, + ) + rebuilt_file = File( + id="file-1", + type=FileType.DOCUMENT, + transfer_method=FileTransferMethod.LOCAL_FILE, + reference=build_file_reference(record_id="upload-1"), + filename="test.txt", + extension=".txt", + mime_type="text/plain", + size=12, + storage_key="canonical-storage-key", + ) + + raw_file = { + **persisted_file.model_dump(mode="json"), + "tenant_id": "legacy-tenant", + } + + with ( + patch("services.workflow_draft_variable_service.storage") as mock_storage, + patch("models.workflow._resolve_workflow_app_tenant_id", return_value="tenant-1"), + patch("models.workflow.build_file_from_stored_mapping", return_value=rebuilt_file) as rebuild_file, + ): + mock_storage.load.return_value = json.dumps(raw_file).encode() + + selector_tuple, variable = 
draft_var_loader._load_offloaded_variable(draft_var) + + assert selector_tuple == ("test-node-id", "test_file") + assert variable.id == "draft-var-id" + assert variable.name == "test_file" + assert variable.description == "test file description" + assert variable.value == rebuilt_file + rebuild_file.assert_called_once_with(file_mapping=raw_file, tenant_id="tenant-1") def test_load_variables_with_offloaded_variables_unit(self, draft_var_loader): """Test load_variables method with mix of regular and offloaded variables.""" diff --git a/api/tests/unit_tests/services/workflow/test_workflow_converter.py b/api/tests/unit_tests/services/workflow/test_workflow_converter.py deleted file mode 100644 index a847c2b4d1..0000000000 --- a/api/tests/unit_tests/services/workflow/test_workflow_converter.py +++ /dev/null @@ -1,431 +0,0 @@ -# test for api/services/workflow/workflow_converter.py -import json -from unittest.mock import MagicMock - -import pytest - -from core.app.app_config.entities import ( - AdvancedChatMessageEntity, - AdvancedChatPromptTemplateEntity, - AdvancedCompletionPromptTemplateEntity, - DatasetEntity, - DatasetRetrieveConfigEntity, - ExternalDataVariableEntity, - ModelConfigEntity, - PromptTemplateEntity, -) -from core.helper import encrypter -from dify_graph.model_runtime.entities.llm_entities import LLMMode -from dify_graph.model_runtime.entities.message_entities import PromptMessageRole -from dify_graph.variables.input_entities import VariableEntity, VariableEntityType -from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint -from models.model import AppMode -from services.workflow.workflow_converter import WorkflowConverter - - -@pytest.fixture -def default_variables(): - value = [ - VariableEntity( - variable="text_input", - label="text-input", - type=VariableEntityType.TEXT_INPUT, - ), - VariableEntity( - variable="paragraph", - label="paragraph", - type=VariableEntityType.PARAGRAPH, - ), - VariableEntity( - variable="select", - 
label="select", - type=VariableEntityType.SELECT, - ), - ] - return value - - -def test__convert_to_start_node(default_variables): - # act - result = WorkflowConverter()._convert_to_start_node(default_variables) - - # assert - assert isinstance(result["data"]["variables"][0]["type"], str) - assert result["data"]["variables"][0]["type"] == "text-input" - assert result["data"]["variables"][0]["variable"] == "text_input" - assert result["data"]["variables"][1]["variable"] == "paragraph" - assert result["data"]["variables"][2]["variable"] == "select" - - -def test__convert_to_http_request_node_for_chatbot(default_variables): - """ - Test convert to http request nodes for chatbot - :return: - """ - app_model = MagicMock() - app_model.id = "app_id" - app_model.tenant_id = "tenant_id" - app_model.mode = AppMode.CHAT - - api_based_extension_id = "api_based_extension_id" - mock_api_based_extension = APIBasedExtension( - tenant_id="tenant_id", - name="api-1", - api_key="encrypted_api_key", - api_endpoint="https://dify.ai", - ) - - mock_api_based_extension.id = api_based_extension_id - workflow_converter = WorkflowConverter() - workflow_converter._get_api_based_extension = MagicMock(return_value=mock_api_based_extension) - - encrypter.decrypt_token = MagicMock(return_value="api_key") - - external_data_variables = [ - ExternalDataVariableEntity( - variable="external_variable", type="api", config={"api_based_extension_id": api_based_extension_id} - ) - ] - - nodes, _ = workflow_converter._convert_to_http_request_node( - app_model=app_model, variables=default_variables, external_data_variables=external_data_variables - ) - - assert len(nodes) == 2 - assert nodes[0]["data"]["type"] == "http-request" - - http_request_node = nodes[0] - - assert http_request_node["data"]["method"] == "post" - assert http_request_node["data"]["url"] == mock_api_based_extension.api_endpoint - assert http_request_node["data"]["authorization"]["type"] == "api-key" - assert 
http_request_node["data"]["authorization"]["config"] == {"type": "bearer", "api_key": "api_key"} - assert http_request_node["data"]["body"]["type"] == "json" - - body_data = http_request_node["data"]["body"]["data"] - - assert body_data - - body_data_json = json.loads(body_data) - assert body_data_json["point"] == APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY - - body_params = body_data_json["params"] - assert body_params["app_id"] == app_model.id - assert body_params["tool_variable"] == external_data_variables[0].variable - assert len(body_params["inputs"]) == 3 - assert body_params["query"] == "{{#sys.query#}}" # for chatbot - - code_node = nodes[1] - assert code_node["data"]["type"] == "code" - - -def test__convert_to_http_request_node_for_workflow_app(default_variables): - """ - Test convert to http request nodes for workflow app - :return: - """ - app_model = MagicMock() - app_model.id = "app_id" - app_model.tenant_id = "tenant_id" - app_model.mode = AppMode.WORKFLOW - - api_based_extension_id = "api_based_extension_id" - mock_api_based_extension = APIBasedExtension( - tenant_id="tenant_id", - name="api-1", - api_key="encrypted_api_key", - api_endpoint="https://dify.ai", - ) - mock_api_based_extension.id = api_based_extension_id - - workflow_converter = WorkflowConverter() - workflow_converter._get_api_based_extension = MagicMock(return_value=mock_api_based_extension) - - encrypter.decrypt_token = MagicMock(return_value="api_key") - - external_data_variables = [ - ExternalDataVariableEntity( - variable="external_variable", type="api", config={"api_based_extension_id": api_based_extension_id} - ) - ] - - nodes, _ = workflow_converter._convert_to_http_request_node( - app_model=app_model, variables=default_variables, external_data_variables=external_data_variables - ) - - assert len(nodes) == 2 - assert nodes[0]["data"]["type"] == "http-request" - - http_request_node = nodes[0] - - assert http_request_node["data"]["method"] == "post" - assert 
http_request_node["data"]["url"] == mock_api_based_extension.api_endpoint - assert http_request_node["data"]["authorization"]["type"] == "api-key" - assert http_request_node["data"]["authorization"]["config"] == {"type": "bearer", "api_key": "api_key"} - assert http_request_node["data"]["body"]["type"] == "json" - - body_data = http_request_node["data"]["body"]["data"] - - assert body_data - - body_data_json = json.loads(body_data) - assert body_data_json["point"] == APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY - - body_params = body_data_json["params"] - assert body_params["app_id"] == app_model.id - assert body_params["tool_variable"] == external_data_variables[0].variable - assert len(body_params["inputs"]) == 3 - assert body_params["query"] == "" - - code_node = nodes[1] - assert code_node["data"]["type"] == "code" - - -def test__convert_to_knowledge_retrieval_node_for_chatbot(): - new_app_mode = AppMode.ADVANCED_CHAT - - dataset_config = DatasetEntity( - dataset_ids=["dataset_id_1", "dataset_id_2"], - retrieve_config=DatasetRetrieveConfigEntity( - retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE, - top_k=5, - score_threshold=0.8, - reranking_model={"reranking_provider_name": "cohere", "reranking_model_name": "rerank-english-v2.0"}, - reranking_enabled=True, - ), - ) - - model_config = ModelConfigEntity(provider="openai", model="gpt-4", mode="chat", parameters={}, stop=[]) - - node = WorkflowConverter()._convert_to_knowledge_retrieval_node( - new_app_mode=new_app_mode, dataset_config=dataset_config, model_config=model_config - ) - assert node is not None - - assert node["data"]["type"] == "knowledge-retrieval" - assert node["data"]["query_variable_selector"] == ["sys", "query"] - assert node["data"]["dataset_ids"] == dataset_config.dataset_ids - assert node["data"]["retrieval_mode"] == dataset_config.retrieve_config.retrieve_strategy.value - assert node["data"]["multiple_retrieval_config"] == { - "top_k": 
dataset_config.retrieve_config.top_k, - "score_threshold": dataset_config.retrieve_config.score_threshold, - "reranking_model": dataset_config.retrieve_config.reranking_model, - } - - -def test__convert_to_knowledge_retrieval_node_for_workflow_app(): - new_app_mode = AppMode.WORKFLOW - - dataset_config = DatasetEntity( - dataset_ids=["dataset_id_1", "dataset_id_2"], - retrieve_config=DatasetRetrieveConfigEntity( - query_variable="query", - retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.MULTIPLE, - top_k=5, - score_threshold=0.8, - reranking_model={"reranking_provider_name": "cohere", "reranking_model_name": "rerank-english-v2.0"}, - reranking_enabled=True, - ), - ) - - model_config = ModelConfigEntity(provider="openai", model="gpt-4", mode="chat", parameters={}, stop=[]) - - node = WorkflowConverter()._convert_to_knowledge_retrieval_node( - new_app_mode=new_app_mode, dataset_config=dataset_config, model_config=model_config - ) - assert node is not None - - assert node["data"]["type"] == "knowledge-retrieval" - assert node["data"]["query_variable_selector"] == ["start", dataset_config.retrieve_config.query_variable] - assert node["data"]["dataset_ids"] == dataset_config.dataset_ids - assert node["data"]["retrieval_mode"] == dataset_config.retrieve_config.retrieve_strategy.value - assert node["data"]["multiple_retrieval_config"] == { - "top_k": dataset_config.retrieve_config.top_k, - "score_threshold": dataset_config.retrieve_config.score_threshold, - "reranking_model": dataset_config.retrieve_config.reranking_model, - } - - -def test__convert_to_llm_node_for_chatbot_simple_chat_model(default_variables): - new_app_mode = AppMode.ADVANCED_CHAT - model = "gpt-4" - model_mode = LLMMode.CHAT - - workflow_converter = WorkflowConverter() - start_node = workflow_converter._convert_to_start_node(default_variables) - graph = { - "nodes": [start_node], - "edges": [], # no need - } - - model_config_mock = MagicMock(spec=ModelConfigEntity) - 
model_config_mock.provider = "openai" - model_config_mock.model = model - model_config_mock.mode = model_mode.value - model_config_mock.parameters = {} - model_config_mock.stop = [] - - prompt_template = PromptTemplateEntity( - prompt_type=PromptTemplateEntity.PromptType.SIMPLE, - simple_prompt_template="You are a helpful assistant {{text_input}}, {{paragraph}}, {{select}}.", - ) - - llm_node = workflow_converter._convert_to_llm_node( - original_app_mode=AppMode.CHAT, - new_app_mode=new_app_mode, - model_config=model_config_mock, - graph=graph, - prompt_template=prompt_template, - ) - - assert llm_node["data"]["type"] == "llm" - assert llm_node["data"]["model"]["name"] == model - assert llm_node["data"]["model"]["mode"] == model_mode.value - template = prompt_template.simple_prompt_template - assert template is not None - for v in default_variables: - template = template.replace("{{" + v.variable + "}}", "{{#start." + v.variable + "#}}") - assert llm_node["data"]["prompt_template"][0]["text"] == template + "\n" - assert llm_node["data"]["context"]["enabled"] is False - - -def test__convert_to_llm_node_for_chatbot_simple_completion_model(default_variables): - new_app_mode = AppMode.ADVANCED_CHAT - model = "gpt-3.5-turbo-instruct" - model_mode = LLMMode.COMPLETION - - workflow_converter = WorkflowConverter() - start_node = workflow_converter._convert_to_start_node(default_variables) - graph = { - "nodes": [start_node], - "edges": [], # no need - } - - model_config_mock = MagicMock(spec=ModelConfigEntity) - model_config_mock.provider = "openai" - model_config_mock.model = model - model_config_mock.mode = model_mode.value - model_config_mock.parameters = {} - model_config_mock.stop = [] - - prompt_template = PromptTemplateEntity( - prompt_type=PromptTemplateEntity.PromptType.SIMPLE, - simple_prompt_template="You are a helpful assistant {{text_input}}, {{paragraph}}, {{select}}.", - ) - - llm_node = workflow_converter._convert_to_llm_node( - 
original_app_mode=AppMode.CHAT, - new_app_mode=new_app_mode, - model_config=model_config_mock, - graph=graph, - prompt_template=prompt_template, - ) - - assert llm_node["data"]["type"] == "llm" - assert llm_node["data"]["model"]["name"] == model - assert llm_node["data"]["model"]["mode"] == model_mode.value - template = prompt_template.simple_prompt_template - assert template is not None - for v in default_variables: - template = template.replace("{{" + v.variable + "}}", "{{#start." + v.variable + "#}}") - assert llm_node["data"]["prompt_template"]["text"] == template + "\n" - assert llm_node["data"]["context"]["enabled"] is False - - -def test__convert_to_llm_node_for_chatbot_advanced_chat_model(default_variables): - new_app_mode = AppMode.ADVANCED_CHAT - model = "gpt-4" - model_mode = LLMMode.CHAT - - workflow_converter = WorkflowConverter() - start_node = workflow_converter._convert_to_start_node(default_variables) - graph = { - "nodes": [start_node], - "edges": [], # no need - } - - model_config_mock = MagicMock(spec=ModelConfigEntity) - model_config_mock.provider = "openai" - model_config_mock.model = model - model_config_mock.mode = model_mode.value - model_config_mock.parameters = {} - model_config_mock.stop = [] - - prompt_template = PromptTemplateEntity( - prompt_type=PromptTemplateEntity.PromptType.ADVANCED, - advanced_chat_prompt_template=AdvancedChatPromptTemplateEntity( - messages=[ - AdvancedChatMessageEntity( - text="You are a helpful assistant named {{name}}.\n\nContext:\n{{#context#}}", - role=PromptMessageRole.SYSTEM, - ), - AdvancedChatMessageEntity(text="Hi.", role=PromptMessageRole.USER), - AdvancedChatMessageEntity(text="Hello!", role=PromptMessageRole.ASSISTANT), - ] - ), - ) - - llm_node = workflow_converter._convert_to_llm_node( - original_app_mode=AppMode.CHAT, - new_app_mode=new_app_mode, - model_config=model_config_mock, - graph=graph, - prompt_template=prompt_template, - ) - - assert llm_node["data"]["type"] == "llm" - assert 
llm_node["data"]["model"]["name"] == model - assert llm_node["data"]["model"]["mode"] == model_mode.value - assert isinstance(llm_node["data"]["prompt_template"], list) - assert prompt_template.advanced_chat_prompt_template is not None - assert len(llm_node["data"]["prompt_template"]) == len(prompt_template.advanced_chat_prompt_template.messages) - template = prompt_template.advanced_chat_prompt_template.messages[0].text - for v in default_variables: - template = template.replace("{{" + v.variable + "}}", "{{#start." + v.variable + "#}}") - assert llm_node["data"]["prompt_template"][0]["text"] == template - - -def test__convert_to_llm_node_for_workflow_advanced_completion_model(default_variables): - new_app_mode = AppMode.ADVANCED_CHAT - model = "gpt-3.5-turbo-instruct" - model_mode = LLMMode.COMPLETION - - workflow_converter = WorkflowConverter() - start_node = workflow_converter._convert_to_start_node(default_variables) - graph = { - "nodes": [start_node], - "edges": [], # no need - } - - model_config_mock = MagicMock(spec=ModelConfigEntity) - model_config_mock.provider = "openai" - model_config_mock.model = model - model_config_mock.mode = model_mode.value - model_config_mock.parameters = {} - model_config_mock.stop = [] - - prompt_template = PromptTemplateEntity( - prompt_type=PromptTemplateEntity.PromptType.ADVANCED, - advanced_completion_prompt_template=AdvancedCompletionPromptTemplateEntity( - prompt="You are a helpful assistant named {{name}}.\n\nContext:\n{{#context#}}\n\nHuman: hi\nAssistant: ", - role_prefix=AdvancedCompletionPromptTemplateEntity.RolePrefixEntity(user="Human", assistant="Assistant"), - ), - ) - - llm_node = workflow_converter._convert_to_llm_node( - original_app_mode=AppMode.CHAT, - new_app_mode=new_app_mode, - model_config=model_config_mock, - graph=graph, - prompt_template=prompt_template, - ) - - assert llm_node["data"]["type"] == "llm" - assert llm_node["data"]["model"]["name"] == model - assert llm_node["data"]["model"]["mode"] == 
model_mode.value - assert isinstance(llm_node["data"]["prompt_template"], dict) - assert prompt_template.advanced_completion_prompt_template is not None - template = prompt_template.advanced_completion_prompt_template.prompt - for v in default_variables: - template = template.replace("{{" + v.variable + "}}", "{{#start." + v.variable + "#}}") - assert llm_node["data"]["prompt_template"]["text"] == template diff --git a/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py b/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py index 0c2be9c79f..e7e72793a3 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_draft_variable_service.py @@ -4,13 +4,19 @@ import uuid from unittest.mock import MagicMock, Mock, patch import pytest +from graphon.enums import BuiltinNodeTypes +from graphon.file import File, FileTransferMethod, FileType +from graphon.variables.segments import StringSegment +from graphon.variables.types import SegmentType from sqlalchemy import Engine from sqlalchemy.orm import Session -from dify_graph.constants import SYSTEM_VARIABLE_NODE_ID -from dify_graph.enums import BuiltinNodeTypes, SystemVariableKey -from dify_graph.variables.segments import StringSegment -from dify_graph.variables.types import SegmentType +from core.workflow.system_variables import SystemVariableKey +from core.workflow.variable_prefixes import ( + CONVERSATION_VARIABLE_NODE_ID, + ENVIRONMENT_VARIABLE_NODE_ID, + SYSTEM_VARIABLE_NODE_ID, +) from libs.uuid_utils import uuidv7 from models.account import Account from models.enums import DraftVariableType @@ -86,6 +92,20 @@ class TestDraftVariableSaver: expected_node_id=_NODE_ID, expected_name="start_input", ), + TestCase( + name="name with `env.` prefix should return the environment node_id", + input_node_id=_NODE_ID, + input_name="env.API_KEY", + expected_node_id=ENVIRONMENT_VARIABLE_NODE_ID, 
+ expected_name="API_KEY", + ), + TestCase( + name="name with `conversation.` prefix should return the conversation node_id", + input_node_id=_NODE_ID, + input_name="conversation.session_id", + expected_node_id=CONVERSATION_VARIABLE_NODE_ID, + expected_name="session_id", + ), TestCase( name="dummy_variable should return the original input node_id", input_node_id=_NODE_ID, @@ -112,6 +132,47 @@ class TestDraftVariableSaver: assert node_id == c.expected_node_id, fail_msg assert name == c.expected_name, fail_msg + def test_build_variables_from_start_mapping_rebuilds_system_files(self): + mock_session = MagicMock(spec=Session) + mock_user = MagicMock(spec=Account) + mock_user.id = str(uuid.uuid4()) + saver = DraftVariableSaver( + session=mock_session, + app_id=self._get_test_app_id(), + node_id="start", + node_type=BuiltinNodeTypes.START, + node_execution_id="exec-1", + user=mock_user, + ) + rebuilt_file = File( + id="file-1", + type=FileType.DOCUMENT, + transfer_method=FileTransferMethod.LOCAL_FILE, + reference="upload-1", + filename="test.txt", + extension=".txt", + mime_type="text/plain", + size=12, + storage_key="canonical-storage-key", + ) + raw_file = { + **rebuilt_file.model_dump(mode="json"), + "tenant_id": "legacy-tenant", + } + + with ( + patch.object(saver, "_resolve_app_tenant_id", return_value="tenant-1"), + patch( + "services.workflow_draft_variable_service.build_file_from_stored_mapping", + return_value=rebuilt_file, + ) as rebuild_file, + ): + draft_vars = saver._build_variables_from_start_mapping({"sys.files": [raw_file]}) + + sys_var = draft_vars[0] + assert sys_var.get_value().value[0] == rebuilt_file + rebuild_file.assert_called_once_with(file_mapping=raw_file, tenant_id="tenant-1") + @pytest.fixture def mock_session(self): """Mock SQLAlchemy session.""" @@ -218,6 +279,46 @@ class TestDraftVariableSaver: str(SystemVariableKey.WORKFLOW_EXECUTION_ID), } + @patch("services.workflow_draft_variable_service._batch_upsert_draft_variable", autospec=True) + 
def test_start_node_save_normalizes_reserved_prefix_outputs(self, mock_batch_upsert): + mock_session = MagicMock(spec=Session) + mock_user = MagicMock(spec=Account) + mock_user.id = "test-user-id" + mock_user.tenant_id = "test-tenant-id" + + saver = DraftVariableSaver( + session=mock_session, + app_id="test-app-id", + node_id="start-node-id", + node_type=BuiltinNodeTypes.START, + node_execution_id="exec-id", + user=mock_user, + ) + + saver.save( + outputs={ + "env.API_KEY": "secret", + "conversation.session_id": "conversation-1", + "sys.workflow_run_id": "run-id-123", + } + ) + + mock_batch_upsert.assert_called_once() + draft_vars = mock_batch_upsert.call_args[0][1] + + assert len(draft_vars) == 3 + + env_var = next(v for v in draft_vars if v.node_id == ENVIRONMENT_VARIABLE_NODE_ID) + assert env_var.name == "API_KEY" + assert env_var.editable is False + + conversation_var = next(v for v in draft_vars if v.node_id == CONVERSATION_VARIABLE_NODE_ID) + assert conversation_var.name == "session_id" + assert conversation_var.node_execution_id is None + + sys_var = next(v for v in draft_vars if v.node_id == SYSTEM_VARIABLE_NODE_ID) + assert sys_var.name == str(SystemVariableKey.WORKFLOW_EXECUTION_ID) + class TestWorkflowDraftVariableService: def _get_test_app_id(self): diff --git a/api/tests/unit_tests/services/workflow/test_workflow_event_snapshot_service.py b/api/tests/unit_tests/services/workflow/test_workflow_event_snapshot_service.py index 6c1adba2b8..077a7c27a2 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_event_snapshot_service.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_event_snapshot_service.py @@ -8,13 +8,13 @@ from datetime import UTC, datetime from threading import Event import pytest +from graphon.entities.pause_reason import HumanInputRequired +from graphon.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus +from graphon.runtime import GraphRuntimeState, VariablePool from core.app.app_config.entities import 
WorkflowUIBasedAppConfig from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext, _WorkflowGenerateEntityWrapper -from dify_graph.entities.pause_reason import HumanInputRequired -from dify_graph.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus -from dify_graph.runtime import GraphRuntimeState, VariablePool from models.enums import CreatorUserRole from models.model import AppMode from models.workflow import WorkflowRun diff --git a/api/tests/unit_tests/services/workflow/test_workflow_human_input_delivery.py b/api/tests/unit_tests/services/workflow/test_workflow_human_input_delivery.py index c890ab6a65..98d057e41f 100644 --- a/api/tests/unit_tests/services/workflow/test_workflow_human_input_delivery.py +++ b/api/tests/unit_tests/services/workflow/test_workflow_human_input_delivery.py @@ -3,16 +3,16 @@ from types import SimpleNamespace from unittest.mock import MagicMock import pytest +from graphon.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter +from graphon.enums import BuiltinNodeTypes +from graphon.nodes.human_input.entities import HumanInputNodeData from sqlalchemy.orm import sessionmaker -from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter -from dify_graph.enums import BuiltinNodeTypes -from dify_graph.nodes.human_input.entities import ( +from core.workflow.human_input_compat import ( EmailDeliveryConfig, EmailDeliveryMethod, EmailRecipients, ExternalRecipient, - HumanInputNodeData, MemberRecipient, ) from services import workflow_service as workflow_service_module @@ -23,7 +23,7 @@ def _make_service() -> WorkflowService: return WorkflowService(session_maker=sessionmaker()) -def _build_node_config(delivery_methods: list[EmailDeliveryMethod]) -> NodeConfigDict: +def _build_node_config(delivery_methods: list[EmailDeliveryMethod], *, legacy: bool = False) -> NodeConfigDict: node_data = 
HumanInputNodeData( title="Human Input", delivery_methods=delivery_methods, @@ -31,6 +31,14 @@ def _build_node_config(delivery_methods: list[EmailDeliveryMethod]) -> NodeConfi inputs=[], user_actions=[], ).model_dump(mode="json") + if legacy: + for delivery_method in node_data["delivery_methods"]: + recipients = delivery_method.get("config", {}).get("recipients", {}) + if "include_bound_group" in recipients: + recipients["whole_workspace"] = recipients.pop("include_bound_group") + for recipient in recipients.get("items", []): + if "reference_id" in recipient: + recipient["user_id"] = recipient.pop("reference_id") node_data["type"] = BuiltinNodeTypes.HUMAN_INPUT return NodeConfigDictAdapter.validate_python({"id": "node-1", "data": node_data}) @@ -41,7 +49,7 @@ def _make_email_method(enabled: bool = True, debug_mode: bool = False) -> EmailD enabled=enabled, config=EmailDeliveryConfig( recipients=EmailRecipients( - whole_workspace=False, + include_bound_group=False, items=[ExternalRecipient(email="tester@example.com")], ), subject="Test subject", @@ -69,7 +77,7 @@ def test_human_input_delivery_requires_draft_workflow(): def test_human_input_delivery_allows_disabled_method(monkeypatch: pytest.MonkeyPatch): service = _make_service() delivery_method = _make_email_method(enabled=False) - node_config = _build_node_config([delivery_method]) + node_config = _build_node_config([delivery_method], legacy=True) workflow = MagicMock() workflow.get_node_config_by_id.return_value = node_config service.get_draft_workflow = MagicMock(return_value=workflow) # type: ignore[method-assign] @@ -105,7 +113,7 @@ def test_human_input_delivery_allows_disabled_method(monkeypatch: pytest.MonkeyP def test_human_input_delivery_dispatches_to_test_service(monkeypatch: pytest.MonkeyPatch): service = _make_service() delivery_method = _make_email_method(enabled=True) - node_config = _build_node_config([delivery_method]) + node_config = _build_node_config([delivery_method], legacy=True) workflow = 
MagicMock() workflow.get_node_config_by_id.return_value = node_config service.get_draft_workflow = MagicMock(return_value=workflow) # type: ignore[method-assign] @@ -144,7 +152,7 @@ def test_human_input_delivery_dispatches_to_test_service(monkeypatch: pytest.Mon def test_human_input_delivery_debug_mode_overrides_recipients(monkeypatch: pytest.MonkeyPatch): service = _make_service() delivery_method = _make_email_method(enabled=True, debug_mode=True) - node_config = _build_node_config([delivery_method]) + node_config = _build_node_config([delivery_method], legacy=True) workflow = MagicMock() workflow.get_node_config_by_id.return_value = node_config service.get_draft_workflow = MagicMock(return_value=workflow) # type: ignore[method-assign] @@ -178,8 +186,8 @@ def test_human_input_delivery_debug_mode_overrides_recipients(monkeypatch: pytes sent_method = test_service_instance.send_test.call_args.kwargs["method"] assert isinstance(sent_method, EmailDeliveryMethod) assert sent_method.config.debug_mode is True - assert sent_method.config.recipients.whole_workspace is False + assert sent_method.config.recipients.include_bound_group is False assert len(sent_method.config.recipients.items) == 1 recipient = sent_method.config.recipients.items[0] assert isinstance(recipient, MemberRecipient) - assert recipient.user_id == account.id + assert recipient.reference_id == account.id diff --git a/api/tests/unit_tests/services/workflow/test_workflow_service.py b/api/tests/unit_tests/services/workflow/test_workflow_service.py deleted file mode 100644 index 538c1b3595..0000000000 --- a/api/tests/unit_tests/services/workflow/test_workflow_service.py +++ /dev/null @@ -1,415 +0,0 @@ -from contextlib import nullcontext -from types import SimpleNamespace -from unittest.mock import MagicMock - -import pytest - -from dify_graph.entities.graph_config import NodeConfigDictAdapter -from dify_graph.enums import BuiltinNodeTypes -from dify_graph.nodes.human_input.entities import FormInput, 
HumanInputNodeData, UserAction -from dify_graph.nodes.human_input.enums import FormInputType -from models.model import App -from models.workflow import Workflow -from services import workflow_service as workflow_service_module -from services.workflow_service import WorkflowService - - -class TestWorkflowService: - @pytest.fixture - def workflow_service(self): - mock_session_maker = MagicMock() - return WorkflowService(mock_session_maker) - - @pytest.fixture - def mock_app(self): - app = MagicMock(spec=App) - app.id = "app-id-1" - app.workflow_id = "workflow-id-1" - app.tenant_id = "tenant-id-1" - return app - - @pytest.fixture - def mock_workflows(self): - workflows = [] - for i in range(5): - workflow = MagicMock(spec=Workflow) - workflow.id = f"workflow-id-{i}" - workflow.app_id = "app-id-1" - workflow.created_at = f"2023-01-0{5 - i}" # Descending date order - workflow.created_by = "user-id-1" if i % 2 == 0 else "user-id-2" - workflow.marked_name = f"Workflow {i}" if i % 2 == 0 else "" - workflows.append(workflow) - return workflows - - @pytest.fixture - def dummy_session_cls(self): - class DummySession: - def __init__(self, *args, **kwargs): - self.commit = MagicMock() - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc, tb): - return False - - def begin(self): - return nullcontext() - - return DummySession - - def test_get_all_published_workflow_no_workflow_id(self, workflow_service, mock_app): - mock_app.workflow_id = None - mock_session = MagicMock() - - workflows, has_more = workflow_service.get_all_published_workflow( - session=mock_session, app_model=mock_app, page=1, limit=10, user_id=None - ) - - assert workflows == [] - assert has_more is False - mock_session.scalars.assert_not_called() - - def test_get_all_published_workflow_basic(self, workflow_service, mock_app, mock_workflows): - mock_session = MagicMock() - mock_scalar_result = MagicMock() - mock_scalar_result.all.return_value = mock_workflows[:3] - 
mock_session.scalars.return_value = mock_scalar_result - - workflows, has_more = workflow_service.get_all_published_workflow( - session=mock_session, app_model=mock_app, page=1, limit=3, user_id=None - ) - - assert workflows == mock_workflows[:3] - assert has_more is False - mock_session.scalars.assert_called_once() - - def test_get_all_published_workflow_pagination(self, workflow_service, mock_app, mock_workflows): - mock_session = MagicMock() - mock_scalar_result = MagicMock() - # Return 4 items when limit is 3, which should indicate has_more=True - mock_scalar_result.all.return_value = mock_workflows[:4] - mock_session.scalars.return_value = mock_scalar_result - - workflows, has_more = workflow_service.get_all_published_workflow( - session=mock_session, app_model=mock_app, page=1, limit=3, user_id=None - ) - - # Should return only the first 3 items - assert len(workflows) == 3 - assert workflows == mock_workflows[:3] - assert has_more is True - - # Test page 2 - mock_scalar_result.all.return_value = mock_workflows[3:] - mock_session.scalars.return_value = mock_scalar_result - - workflows, has_more = workflow_service.get_all_published_workflow( - session=mock_session, app_model=mock_app, page=2, limit=3, user_id=None - ) - - assert len(workflows) == 2 - assert has_more is False - - def test_get_all_published_workflow_user_filter(self, workflow_service, mock_app, mock_workflows): - mock_session = MagicMock() - mock_scalar_result = MagicMock() - # Filter workflows for user-id-1 - filtered_workflows = [w for w in mock_workflows if w.created_by == "user-id-1"] - mock_scalar_result.all.return_value = filtered_workflows - mock_session.scalars.return_value = mock_scalar_result - - workflows, has_more = workflow_service.get_all_published_workflow( - session=mock_session, app_model=mock_app, page=1, limit=10, user_id="user-id-1" - ) - - assert workflows == filtered_workflows - assert has_more is False - mock_session.scalars.assert_called_once() - - # Verify that the 
select contains a user filter clause - args = mock_session.scalars.call_args[0][0] - assert "created_by" in str(args) - - def test_get_all_published_workflow_named_only(self, workflow_service, mock_app, mock_workflows): - mock_session = MagicMock() - mock_scalar_result = MagicMock() - # Filter workflows that have a marked_name - named_workflows = [w for w in mock_workflows if w.marked_name] - mock_scalar_result.all.return_value = named_workflows - mock_session.scalars.return_value = mock_scalar_result - - workflows, has_more = workflow_service.get_all_published_workflow( - session=mock_session, app_model=mock_app, page=1, limit=10, user_id=None, named_only=True - ) - - assert workflows == named_workflows - assert has_more is False - mock_session.scalars.assert_called_once() - - # Verify that the select contains a named_only filter clause - args = mock_session.scalars.call_args[0][0] - assert "marked_name !=" in str(args) - - def test_get_all_published_workflow_combined_filters(self, workflow_service, mock_app, mock_workflows): - mock_session = MagicMock() - mock_scalar_result = MagicMock() - # Combined filter: user-id-1 and has marked_name - filtered_workflows = [w for w in mock_workflows if w.created_by == "user-id-1" and w.marked_name] - mock_scalar_result.all.return_value = filtered_workflows - mock_session.scalars.return_value = mock_scalar_result - - workflows, has_more = workflow_service.get_all_published_workflow( - session=mock_session, app_model=mock_app, page=1, limit=10, user_id="user-id-1", named_only=True - ) - - assert workflows == filtered_workflows - assert has_more is False - mock_session.scalars.assert_called_once() - - # Verify that both filters are applied - args = mock_session.scalars.call_args[0][0] - assert "created_by" in str(args) - assert "marked_name !=" in str(args) - - def test_get_all_published_workflow_empty_result(self, workflow_service, mock_app): - mock_session = MagicMock() - mock_scalar_result = MagicMock() - 
mock_scalar_result.all.return_value = [] - mock_session.scalars.return_value = mock_scalar_result - - workflows, has_more = workflow_service.get_all_published_workflow( - session=mock_session, app_model=mock_app, page=1, limit=10, user_id=None - ) - - assert workflows == [] - assert has_more is False - mock_session.scalars.assert_called_once() - - def test_submit_human_input_form_preview_uses_rendered_content( - self, - workflow_service: WorkflowService, - monkeypatch: pytest.MonkeyPatch, - dummy_session_cls, - ) -> None: - service = workflow_service - node_data = HumanInputNodeData( - title="Human Input", - form_content="

{{#$output.name#}}

", - inputs=[FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="name")], - user_actions=[UserAction(id="approve", title="Approve")], - ) - node = MagicMock() - node.node_data = node_data - node.render_form_content_before_submission.return_value = "

preview

" - node.render_form_content_with_outputs.return_value = "

rendered

" - - service._build_human_input_variable_pool = MagicMock(return_value=MagicMock()) # type: ignore[method-assign] - service._build_human_input_node = MagicMock(return_value=node) # type: ignore[method-assign] - - workflow = MagicMock() - node_config = NodeConfigDictAdapter.validate_python( - {"id": "node-1", "data": {"type": BuiltinNodeTypes.HUMAN_INPUT}} - ) - workflow.get_node_config_by_id.return_value = node_config - workflow.get_enclosing_node_type_and_id.return_value = None - service.get_draft_workflow = MagicMock(return_value=workflow) # type: ignore[method-assign] - - saved_outputs: dict[str, object] = {} - - class DummySaver: - def __init__(self, *args, **kwargs): - pass - - def save(self, outputs, process_data): - saved_outputs.update(outputs) - - monkeypatch.setattr(workflow_service_module, "Session", dummy_session_cls) - monkeypatch.setattr(workflow_service_module, "DraftVariableSaver", DummySaver) - monkeypatch.setattr(workflow_service_module, "db", SimpleNamespace(engine=MagicMock())) - - app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1") - account = SimpleNamespace(id="account-1") - - result = service.submit_human_input_form_preview( - app_model=app_model, - account=account, - node_id="node-1", - form_inputs={"name": "Ada", "extra": "ignored"}, - inputs={"#node-0.result#": "LLM output"}, - action="approve", - ) - - service._build_human_input_variable_pool.assert_called_once_with( - app_model=app_model, - workflow=workflow, - node_config=node_config, - manual_inputs={"#node-0.result#": "LLM output"}, - user_id="account-1", - ) - - node.render_form_content_with_outputs.assert_called_once() - called_args = node.render_form_content_with_outputs.call_args.args - assert called_args[0] == "

preview

" - assert called_args[2] == node_data.outputs_field_names() - rendered_outputs = called_args[1] - assert rendered_outputs["name"] == "Ada" - assert rendered_outputs["extra"] == "ignored" - assert "extra" in saved_outputs - assert "extra" in result - assert saved_outputs["name"] == "Ada" - assert result["name"] == "Ada" - assert result["__action_id"] == "approve" - assert "__rendered_content" in result - - def test_submit_human_input_form_preview_missing_inputs_message(self, workflow_service: WorkflowService) -> None: - service = workflow_service - node_data = HumanInputNodeData( - title="Human Input", - form_content="

{{#$output.name#}}

", - inputs=[FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="name")], - user_actions=[UserAction(id="approve", title="Approve")], - ) - node = MagicMock() - node.node_data = node_data - node._render_form_content_before_submission.return_value = "

preview

" - node._render_form_content_with_outputs.return_value = "

rendered

" - - service._build_human_input_variable_pool = MagicMock(return_value=MagicMock()) # type: ignore[method-assign] - service._build_human_input_node = MagicMock(return_value=node) # type: ignore[method-assign] - - workflow = MagicMock() - workflow.get_node_config_by_id.return_value = NodeConfigDictAdapter.validate_python( - {"id": "node-1", "data": {"type": BuiltinNodeTypes.HUMAN_INPUT}} - ) - service.get_draft_workflow = MagicMock(return_value=workflow) # type: ignore[method-assign] - - app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1") - account = SimpleNamespace(id="account-1") - with pytest.raises(ValueError) as exc_info: - service.submit_human_input_form_preview( - app_model=app_model, - account=account, - node_id="node-1", - form_inputs={}, - inputs={}, - action="approve", - ) - - assert "Missing required inputs" in str(exc_info.value) - - def test_run_draft_workflow_node_successful_behavior( - self, workflow_service, mock_app, monkeypatch, dummy_session_cls - ): - """Behavior: When a basic workflow node runs, it correctly sets up context, - executes the node, and saves outputs.""" - service = workflow_service - account = SimpleNamespace(id="account-1") - mock_workflow = MagicMock() - mock_workflow.id = "wf-1" - mock_workflow.tenant_id = "tenant-1" - mock_workflow.environment_variables = [] - mock_workflow.conversation_variables = [] - - # Mock node config - mock_workflow.get_node_config_by_id.return_value = NodeConfigDictAdapter.validate_python( - {"id": "node-1", "data": {"type": BuiltinNodeTypes.LLM}} - ) - mock_workflow.get_enclosing_node_type_and_id.return_value = None - - # Mock class methods - monkeypatch.setattr(workflow_service_module, "WorkflowDraftVariableService", MagicMock()) - monkeypatch.setattr(workflow_service_module, "DraftVarLoader", MagicMock()) - - # Mock workflow entry execution - mock_node_exec = MagicMock() - mock_node_exec.id = "exec-1" - mock_node_exec.process_data = {} - mock_run = MagicMock() - 
monkeypatch.setattr(workflow_service_module.WorkflowEntry, "single_step_run", mock_run) - - # Mock execution handling - service._handle_single_step_result = MagicMock(return_value=mock_node_exec) - - # Mock repository - mock_repo = MagicMock() - mock_repo.get_execution_by_id.return_value = mock_node_exec - mock_repo_factory = MagicMock(return_value=mock_repo) - monkeypatch.setattr( - workflow_service_module.DifyCoreRepositoryFactory, - "create_workflow_node_execution_repository", - mock_repo_factory, - ) - service._node_execution_service_repo = mock_repo - - # Set up node execution service repo mock to return our exec node - mock_node_exec.load_full_outputs.return_value = {"output_var": "result_value"} - mock_node_exec.node_id = "node-1" - mock_node_exec.node_type = "llm" - - # Mock draft variable saver - mock_saver = MagicMock() - monkeypatch.setattr(workflow_service_module, "DraftVariableSaver", MagicMock(return_value=mock_saver)) - - # Mock DB - monkeypatch.setattr(workflow_service_module, "db", SimpleNamespace(engine=MagicMock())) - - monkeypatch.setattr(workflow_service_module, "Session", dummy_session_cls) - - # Act - result = service.run_draft_workflow_node( - app_model=mock_app, - draft_workflow=mock_workflow, - node_id="node-1", - user_inputs={"input_val": "test"}, - account=account, - ) - - # Assert - assert result == mock_node_exec - service._handle_single_step_result.assert_called_once() - mock_repo.save.assert_called_once_with(mock_node_exec) - mock_saver.save.assert_called_once_with(process_data={}, outputs={"output_var": "result_value"}) - - def test_run_draft_workflow_node_failure_behavior(self, workflow_service, mock_app, monkeypatch, dummy_session_cls): - """Behavior: If retrieving the saved execution fails, an appropriate error bubble matches expectations.""" - service = workflow_service - account = SimpleNamespace(id="account-1") - mock_workflow = MagicMock() - mock_workflow.tenant_id = "tenant-1" - mock_workflow.environment_variables = [] - 
mock_workflow.conversation_variables = [] - mock_workflow.get_node_config_by_id.return_value = NodeConfigDictAdapter.validate_python( - {"id": "node-1", "data": {"type": BuiltinNodeTypes.LLM}} - ) - mock_workflow.get_enclosing_node_type_and_id.return_value = None - - monkeypatch.setattr(workflow_service_module, "WorkflowDraftVariableService", MagicMock()) - monkeypatch.setattr(workflow_service_module, "DraftVarLoader", MagicMock()) - monkeypatch.setattr(workflow_service_module.WorkflowEntry, "single_step_run", MagicMock()) - - mock_node_exec = MagicMock() - mock_node_exec.id = "exec-invalid" - service._handle_single_step_result = MagicMock(return_value=mock_node_exec) - - mock_repo = MagicMock() - mock_repo_factory = MagicMock(return_value=mock_repo) - monkeypatch.setattr( - workflow_service_module.DifyCoreRepositoryFactory, - "create_workflow_node_execution_repository", - mock_repo_factory, - ) - service._node_execution_service_repo = mock_repo - - # Simulate failure to retrieve the saved execution - mock_repo.get_execution_by_id.return_value = None - - monkeypatch.setattr(workflow_service_module, "db", SimpleNamespace(engine=MagicMock())) - - monkeypatch.setattr(workflow_service_module, "Session", dummy_session_cls) - - # Act & Assert - with pytest.raises(ValueError, match="WorkflowNodeExecution with id exec-invalid not found after saving"): - service.run_draft_workflow_node( - app_model=mock_app, draft_workflow=mock_workflow, node_id="node-1", user_inputs={}, account=account - ) diff --git a/api/tests/unit_tests/tasks/test_enterprise_telemetry_task.py b/api/tests/unit_tests/tasks/test_enterprise_telemetry_task.py new file mode 100644 index 0000000000..b48c69a146 --- /dev/null +++ b/api/tests/unit_tests/tasks/test_enterprise_telemetry_task.py @@ -0,0 +1,69 @@ +"""Unit tests for enterprise telemetry Celery task.""" + +import json +from unittest.mock import MagicMock, patch + +import pytest + +from enterprise.telemetry.contracts import TelemetryCase, 
TelemetryEnvelope +from tasks.enterprise_telemetry_task import process_enterprise_telemetry + + +@pytest.fixture +def sample_envelope_json(): + envelope = TelemetryEnvelope( + case=TelemetryCase.APP_CREATED, + tenant_id="test-tenant", + event_id="test-event-123", + payload={"app_id": "app-123"}, + ) + return envelope.model_dump_json() + + +def test_process_enterprise_telemetry_success(sample_envelope_json): + with patch("tasks.enterprise_telemetry_task.EnterpriseMetricHandler") as mock_handler_class: + mock_handler = MagicMock() + mock_handler_class.return_value = mock_handler + + process_enterprise_telemetry(sample_envelope_json) + + mock_handler.handle.assert_called_once() + call_args = mock_handler.handle.call_args[0][0] + assert isinstance(call_args, TelemetryEnvelope) + assert call_args.case == TelemetryCase.APP_CREATED + assert call_args.tenant_id == "test-tenant" + assert call_args.event_id == "test-event-123" + + +def test_process_enterprise_telemetry_invalid_json(caplog): + invalid_json = "not valid json" + + process_enterprise_telemetry(invalid_json) + + assert "Failed to process enterprise telemetry envelope" in caplog.text + + +def test_process_enterprise_telemetry_handler_exception(sample_envelope_json, caplog): + with patch("tasks.enterprise_telemetry_task.EnterpriseMetricHandler") as mock_handler_class: + mock_handler = MagicMock() + mock_handler.handle.side_effect = Exception("Handler error") + mock_handler_class.return_value = mock_handler + + process_enterprise_telemetry(sample_envelope_json) + + assert "Failed to process enterprise telemetry envelope" in caplog.text + + +def test_process_enterprise_telemetry_validation_error(caplog): + invalid_envelope = json.dumps( + { + "case": "INVALID_CASE", + "tenant_id": "test-tenant", + "event_id": "test-event", + "payload": {}, + } + ) + + process_enterprise_telemetry(invalid_envelope) + + assert "Failed to process enterprise telemetry envelope" in caplog.text diff --git 
a/api/tests/unit_tests/tasks/test_human_input_timeout_tasks.py b/api/tests/unit_tests/tasks/test_human_input_timeout_tasks.py index bd0182a402..7119217e94 100644 --- a/api/tests/unit_tests/tasks/test_human_input_timeout_tasks.py +++ b/api/tests/unit_tests/tasks/test_human_input_timeout_tasks.py @@ -5,8 +5,8 @@ from types import SimpleNamespace from typing import Any import pytest +from graphon.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus -from dify_graph.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus from tasks import human_input_timeout_tasks as task_module diff --git a/api/tests/unit_tests/tasks/test_workflow_node_execution_tasks.py b/api/tests/unit_tests/tasks/test_workflow_node_execution_tasks.py index a223f0119e..8cac696d98 100644 --- a/api/tests/unit_tests/tasks/test_workflow_node_execution_tasks.py +++ b/api/tests/unit_tests/tasks/test_workflow_node_execution_tasks.py @@ -11,11 +11,11 @@ # import pytest -# from dify_graph.entities.workflow_node_execution import ( +# from graphon.entities.workflow_node_execution import ( # WorkflowNodeExecution, # WorkflowNodeExecutionStatus, # ) -# from dify_graph.enums import BuiltinNodeTypes +# from graphon.enums import BuiltinNodeTypes # from libs.datetime_utils import naive_utc_now # from models import WorkflowNodeExecutionModel # from models.enums import ExecutionOffLoadType diff --git a/api/tests/unit_tests/tools/test_mcp_tool.py b/api/tests/unit_tests/tools/test_mcp_tool.py index fa9c6af287..68359ba078 100644 --- a/api/tests/unit_tests/tools/test_mcp_tool.py +++ b/api/tests/unit_tests/tools/test_mcp_tool.py @@ -3,6 +3,7 @@ from decimal import Decimal from unittest.mock import Mock, patch import pytest +from graphon.model_runtime.entities.llm_entities import LLMUsage from core.mcp.types import ( AudioContent, @@ -17,7 +18,6 @@ from core.tools.__base.tool_runtime import ToolRuntime from core.tools.entities.common_entities import I18nObject from 
core.tools.entities.tool_entities import ToolEntity, ToolIdentity, ToolInvokeMessage from core.tools.mcp_tool.tool import MCPTool -from dify_graph.model_runtime.entities.llm_entities import LLMUsage def _make_mcp_tool(output_schema: dict | None = None) -> MCPTool: diff --git a/api/tests/unit_tests/utils/structured_output_parser/test_structured_output_parser.py b/api/tests/unit_tests/utils/structured_output_parser/test_structured_output_parser.py index 7ec1343f98..ffa6833524 100644 --- a/api/tests/unit_tests/utils/structured_output_parser/test_structured_output_parser.py +++ b/api/tests/unit_tests/utils/structured_output_parser/test_structured_output_parser.py @@ -2,10 +2,7 @@ from decimal import Decimal from unittest.mock import MagicMock, patch import pytest - -from core.llm_generator.output_parser.errors import OutputParserError -from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output -from dify_graph.model_runtime.entities.llm_entities import ( +from graphon.model_runtime.entities.llm_entities import ( LLMResult, LLMResultChunk, LLMResultChunkDelta, @@ -13,13 +10,16 @@ from dify_graph.model_runtime.entities.llm_entities import ( LLMResultWithStructuredOutput, LLMUsage, ) -from dify_graph.model_runtime.entities.message_entities import ( +from graphon.model_runtime.entities.message_entities import ( AssistantPromptMessage, SystemPromptMessage, TextPromptMessageContent, UserPromptMessage, ) -from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelType +from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelType + +from core.llm_generator.output_parser.errors import OutputParserError +from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output def create_mock_usage(prompt_tokens: int = 10, completion_tokens: int = 5) -> LLMUsage: @@ -336,7 +336,6 @@ def test_structured_output_parser(): json_schema=case["json_schema"], stream=case["stream"], 
model_parameters={"temperature": 0.7, "max_tokens": 100}, - user="test_user", ) if case["expected_result_type"] == "generator": @@ -367,7 +366,7 @@ def test_structured_output_parser(): call_args = model_instance.invoke_llm.call_args assert call_args.kwargs["stream"] == case["stream"] - assert call_args.kwargs["user"] == "test_user" + assert "user" not in call_args.kwargs assert "temperature" in call_args.kwargs["model_parameters"] assert "max_tokens" in call_args.kwargs["model_parameters"] diff --git a/api/tests/workflow_test_utils.py b/api/tests/workflow_test_utils.py index 1f0bf8ef37..d33ac2c710 100644 --- a/api/tests/workflow_test_utils.py +++ b/api/tests/workflow_test_utils.py @@ -1,8 +1,12 @@ from collections.abc import Mapping from typing import Any +from graphon.entities import GraphInitParams +from graphon.runtime import VariablePool +from graphon.variables.variables import Variable + from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context -from dify_graph.entities.graph_init_params import GraphInitParams +from core.workflow.variable_pool_initializer import add_node_inputs_to_pool, add_variables_to_pool def build_test_run_context( @@ -51,3 +55,16 @@ def build_test_graph_init_params( ), call_depth=call_depth, ) + + +def build_test_variable_pool( + *, + variables: list[Variable] | tuple[Variable, ...] 
= (), + node_id: str | None = None, + inputs: Mapping[str, Any] | None = None, +) -> VariablePool: + variable_pool = VariablePool() + add_variables_to_pool(variable_pool, variables) + if node_id is not None and inputs is not None: + add_node_inputs_to_pool(variable_pool, node_id=node_id, inputs=inputs) + return variable_pool diff --git a/api/uv.lock b/api/uv.lock index 47a3c45df0..c4cf31e3f5 100644 --- a/api/uv.lock +++ b/api/uv.lock @@ -204,7 +204,7 @@ sdist = { url = "https://files.pythonhosted.org/packages/9a/7d/b22cb9a0d4f396ee0 [[package]] name = "alibabacloud-tea-openapi" -version = "0.4.3" +version = "0.4.4" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "alibabacloud-credentials" }, @@ -213,9 +213,9 @@ dependencies = [ { name = "cryptography" }, { name = "darabonba-core" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/91/4f/b5288eea8f4d4b032c9a8f2cd1d926d5017977d10b874956f31e5343f299/alibabacloud_tea_openapi-0.4.3.tar.gz", hash = "sha256:12aef036ed993637b6f141abbd1de9d6199d5516f4a901588bb65d6a3768d41b", size = 21864, upload-time = "2026-01-15T07:55:16.744Z" } +sdist = { url = "https://files.pythonhosted.org/packages/30/93/138bcdc8fc596add73e37cf2073798f285284d1240bda9ee02f9384fc6be/alibabacloud_tea_openapi-0.4.4.tar.gz", hash = "sha256:1b0917bc03cd49417da64945e92731716d53e2eb8707b235f54e45b7473221ce", size = 21960, upload-time = "2026-03-26T10:16:16.792Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/a5/37/48ee5468ecad19c6d44cf3b9629d77078e836ee3ec760f0366247f307b7c/alibabacloud_tea_openapi-0.4.3-py3-none-any.whl", hash = "sha256:d0b3a373b760ef6278b25fc128c73284301e07888977bf97519e7636d47bdf0a", size = 26159, upload-time = "2026-01-15T07:55:15.72Z" }, + { url = "https://files.pythonhosted.org/packages/f5/5a/6bfc4506438c1809c486f66217ad11eab78157192b3d5707b4e2f4212f6c/alibabacloud_tea_openapi-0.4.4-py3-none-any.whl", hash = "sha256:cea6bc1fe35b0319a8752cb99eb0ecb0dab7ca1a71b99c12970ba0867410995f", 
size = 26236, upload-time = "2026-03-26T10:16:15.861Z" }, ] [[package]] @@ -494,28 +494,29 @@ wheels = [ [[package]] name = "basedpyright" -version = "1.38.3" +version = "1.38.4" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "nodejs-wheel-binaries" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/0f/58/7abba2c743571a42b2548f07aee556ebc1e4d0bc2b277aeba1ee6c83b0af/basedpyright-1.38.3.tar.gz", hash = "sha256:9725419786afbfad8a9539527f162da02d462afad440b0412fdb3f3cdf179b90", size = 25277430, upload-time = "2026-03-17T13:10:41.526Z" } +sdist = { url = "https://files.pythonhosted.org/packages/08/b4/26cb812eaf8ab56909c792c005fe1690706aef6f21d61107639e46e9c54c/basedpyright-1.38.4.tar.gz", hash = "sha256:8e7d4f37ffb6106621e06b9355025009cdf5b48f71c592432dd2dd304bf55e70", size = 25354730, upload-time = "2026-03-25T13:50:44.353Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/2c/e3/3ebb5c23bd3abb5fc2053b8a06a889aa5c1cf8cff738c78cb6c1957e90cd/basedpyright-1.38.3-py3-none-any.whl", hash = "sha256:1f15c2e489c67d6c5e896c24b6a63251195c04223a55e4568b8f8e8ed49ca830", size = 12313363, upload-time = "2026-03-17T13:10:47.344Z" }, + { url = "https://files.pythonhosted.org/packages/62/0b/3f95fd47def42479e61077523d3752086d5c12009192a7f1c9fd5507e687/basedpyright-1.38.4-py3-none-any.whl", hash = "sha256:90aa067cf3e8a3c17ad5836a72b9e1f046bc72a4ad57d928473d9368c9cd07a2", size = 12352258, upload-time = "2026-03-25T13:50:41.059Z" }, ] [[package]] name = "bce-python-sdk" -version = "0.9.64" +version = "0.9.67" source = { registry = "https://pypi.org/simple" } dependencies = [ + { name = "crc32c" }, { name = "future" }, { name = "pycryptodome" }, { name = "six" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/61/33/047e9c1a6c97e0cd4d93a6490abd8fbc2ccd13569462fc0228699edc08bc/bce_python_sdk-0.9.64.tar.gz", hash = "sha256:901bf787c26ad35855a80d65e58d7584c8541f7f0f2af20847830e572e5b622e", size = 287125, upload-time = 
"2026-03-17T11:24:29.345Z" } +sdist = { url = "https://files.pythonhosted.org/packages/b5/b9/5140cc02832fe3a7394c52949796d43f8c1f635aa016100f857f504e0348/bce_python_sdk-0.9.67.tar.gz", hash = "sha256:2c673d757c5c8952f1be6611da4ab77a63ecabaa3ff22b11531f46845ac99e58", size = 295251, upload-time = "2026-03-24T14:10:07.086Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/48/7f/dd289582f37ab4effea47b2a8503880db4781ca0fc8e0a8ed5ff493359e5/bce_python_sdk-0.9.64-py3-none-any.whl", hash = "sha256:eaad97e4f0e7d613ae978da3cdc5294e9f724ffca2735f79820037fa1317cd6d", size = 402233, upload-time = "2026-03-17T11:24:24.673Z" }, + { url = "https://files.pythonhosted.org/packages/d4/a9/a58a63e2756e5d01901595af58c673f68de7621f28d71007479e00f45a6c/bce_python_sdk-0.9.67-py3-none-any.whl", hash = "sha256:3054879d098a92ceeb4b9ac1e64d2c658120a5a10e8e630f22410564b2170bf0", size = 410854, upload-time = "2026-03-24T14:09:54.29Z" }, ] [[package]] @@ -630,30 +631,30 @@ wheels = [ [[package]] name = "boto3" -version = "1.42.73" +version = "1.42.78" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "botocore" }, { name = "jmespath" }, { name = "s3transfer" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/e4/8b/d00575be514744ca4839e7d85bf4a8a3c7b6b4574433291e58d14c68ae09/boto3-1.42.73.tar.gz", hash = "sha256:d37b58d6cd452ca808dd6823ae19ca65b6244096c5125ef9052988b337298bae", size = 112775, upload-time = "2026-03-20T19:39:52.814Z" } +sdist = { url = "https://files.pythonhosted.org/packages/a8/2b/ebdad075934cf6bb78bf81fe31d83339bcd804ad6c856f7341376cbc88b6/boto3-1.42.78.tar.gz", hash = "sha256:cef2ebdb9be5c0e96822f8d3941ac4b816c90a5737a7ffb901d664c808964b63", size = 112789, upload-time = "2026-03-27T19:28:07.58Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/aa/05/1fcf03d90abaa3d0b42a6bfd10231dd709493ecbacf794aa2eea5eae6841/boto3-1.42.73-py3-none-any.whl", hash = 
"sha256:1f81b79b873f130eeab14bb556417a7c66d38f3396b7f2fe3b958b3f9094f455", size = 140556, upload-time = "2026-03-20T19:39:50.298Z" }, + { url = "https://files.pythonhosted.org/packages/57/bb/1f6dade1f1e86858bef7bd332bc8106c445f2dbabec7b32ab5d7d118c9b6/boto3-1.42.78-py3-none-any.whl", hash = "sha256:480a34a077484a5ca60124dfd150ba3ea6517fc89963a679e45b30c6db614d26", size = 140556, upload-time = "2026-03-27T19:28:06.125Z" }, ] [[package]] name = "boto3-stubs" -version = "1.42.73" +version = "1.42.78" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "botocore-stubs" }, { name = "types-s3transfer" }, { name = "typing-extensions", marker = "python_full_version < '3.12'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/b9/c3/fcc47102c63278af25ad57c93d97dc393f4dbc54c0117a29c78f2b96ec1e/boto3_stubs-1.42.73.tar.gz", hash = "sha256:36f625769b5505c4bc627f16244b98de9e10dae3ac36f1aa0f0ebe2f201dc138", size = 101373, upload-time = "2026-03-20T19:59:51.463Z" } +sdist = { url = "https://files.pythonhosted.org/packages/03/16/4bdb3c1f69bf7b97dd8b22fe5b007e9da67ba3f00ed10e47146f5fd9d0ff/boto3_stubs-1.42.78.tar.gz", hash = "sha256:423335b8ce9a935e404054978589cdb98d9fa1d4bd46073d6821bf1c3fad8ca7", size = 101602, upload-time = "2026-03-27T19:35:51.149Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/4b/57/d570ba61a2a0c7fe0c8667e41269a0480293cb53e1786d6661a2bd827fc5/boto3_stubs-1.42.73-py3-none-any.whl", hash = "sha256:bd658429069d8215247fc3abc003220cd875c24ab6eda7b3405090408afaacdf", size = 70009, upload-time = "2026-03-20T19:59:43.786Z" }, + { url = "https://files.pythonhosted.org/packages/22/d5/bdedd4951c795899ac5a1f0b88d81b9e2c6333cb87457f2edd11ef3b7b7b/boto3_stubs-1.42.78-py3-none-any.whl", hash = "sha256:6ed07e734174751da8d01031d9ede8d81a88e4338d9e6b00ce7a6bc870075372", size = 70161, upload-time = "2026-03-27T19:35:46.336Z" }, ] [package.optional-dependencies] @@ -663,16 +664,16 @@ bedrock-runtime = [ [[package]] name = 
"botocore" -version = "1.42.73" +version = "1.42.78" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "jmespath" }, { name = "python-dateutil" }, { name = "urllib3" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/28/23/0c88ca116ef63b1ae77c901cd5d2095d22a8dbde9e80df74545db4a061b4/botocore-1.42.73.tar.gz", hash = "sha256:575858641e4949aaf2af1ced145b8524529edf006d075877af6b82ff96ad854c", size = 15008008, upload-time = "2026-03-20T19:39:40.082Z" } +sdist = { url = "https://files.pythonhosted.org/packages/67/8e/cdb34c8ca71216d214e049ada2148ee08bcda12b1ac72af3a720dea300ff/botocore-1.42.78.tar.gz", hash = "sha256:61cbd49728e23f68cfd945406ab40044d49abed143362f7ffa4a4f4bd4311791", size = 15023592, upload-time = "2026-03-27T19:27:57.122Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/8e/65/971f3d55015f4d133a6ff3ad74cd39f4b8dd8f53f7775a3c2ad378ea5145/botocore-1.42.73-py3-none-any.whl", hash = "sha256:7b62e2a12f7a1b08eb7360eecd23bb16fe3b7ab7f5617cf91b25476c6f86a0fe", size = 14681861, upload-time = "2026-03-20T19:39:35.341Z" }, + { url = "https://files.pythonhosted.org/packages/54/72/94bba1a375d45c685b00e051b56142359547837086a83861d76f6aec26f4/botocore-1.42.78-py3-none-any.whl", hash = "sha256:038ab63c7f898e8b5db58cb6a45e4da56c31dd984e7e995839a3540c735564ea", size = 14701729, upload-time = "2026-03-27T19:27:54.05Z" }, ] [[package]] @@ -1062,7 +1063,7 @@ wheels = [ [[package]] name = "clickhouse-connect" -version = "0.14.1" +version = "0.15.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "certifi" }, @@ -1071,24 +1072,24 @@ dependencies = [ { name = "urllib3" }, { name = "zstandard" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/f5/0e/96958db88b6ce6e9d96dc7a836f12c7644934b3a436b04843f19eb8da2db/clickhouse_connect-0.14.1.tar.gz", hash = "sha256:dc107ae9ab7b86409049ae8abe21817543284b438291796d3dd639ad5496a1ab", size = 120093, upload-time = "2026-03-12T15:51:03.606Z" } +sdist 
= { url = "https://files.pythonhosted.org/packages/ec/59/c0b0a2c2e4c204e5baeca4917a95cc95add651da3cec86ec464a8e54cfa0/clickhouse_connect-0.15.0.tar.gz", hash = "sha256:529fcf072df335d18ae16339d99389190f4bd543067dcdc174541c7a9c622ef5", size = 126344, upload-time = "2026-03-26T18:34:52.316Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/66/b0/04bc82ca70d4dcc35987c83e4ef04f6dec3c29d3cce4cda3523ebf4498dc/clickhouse_connect-0.14.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:f2b1d1acb8f64c3cd9d922d9e8c0b6328238c4a38e084598c86cc95a0edbd8bd", size = 278797, upload-time = "2026-03-12T15:49:34.728Z" }, - { url = "https://files.pythonhosted.org/packages/97/03/f8434ed43946dcab2d8b4ccf8e90b1c6d69abea0fa8b8aaddb1dc9931657/clickhouse_connect-0.14.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:573f3e5a6b49135b711c086050f46510d4738cc09e5a354cc18ef26f8de5cd98", size = 271849, upload-time = "2026-03-12T15:49:35.881Z" }, - { url = "https://files.pythonhosted.org/packages/a0/db/b3665f4d855c780be8d00638d874fc0d62613d1f1c06ffcad7c11a333f06/clickhouse_connect-0.14.1-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:86b28932faab182a312779e5c3cf341abe19d31028a399bda9d8b06b3b9adab4", size = 1090975, upload-time = "2026-03-12T15:49:37.064Z" }, - { url = "https://files.pythonhosted.org/packages/ea/a2/7ba2d9669c5771734573397b034169653cdf3348dc4cc66bd66d8ab18910/clickhouse_connect-0.14.1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:bfc9650906ff96452c2b5676a7e68e8a77a5642504596f8482e0f3c0ccdffbf1", size = 1095899, upload-time = "2026-03-12T15:49:38.36Z" }, - { url = "https://files.pythonhosted.org/packages/e2/f4/0394af37b491ca832610f2ca7a129e85d8d857d40c94a42f2c2e6d3d9481/clickhouse_connect-0.14.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:b379749a962599f9d6ec81e773a3b907ac58b001f4a977e4ac397f6a76fedff2", size = 1077567, upload-time = 
"2026-03-12T15:49:40.027Z" }, - { url = "https://files.pythonhosted.org/packages/9a/b8/9279a88afac94c262b55cc75aadc6a3e83f7fa1641e618f9060d9d38415f/clickhouse_connect-0.14.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:43ccb5debd13d41b97af81940c0cac01e92d39f17131d984591bedee13439a5d", size = 1100264, upload-time = "2026-03-12T15:49:41.414Z" }, - { url = "https://files.pythonhosted.org/packages/19/36/20e19ab392c211b83c967e275eb46f663853e0b8ce4da89056fda8a35fc6/clickhouse_connect-0.14.1-cp311-cp311-win32.whl", hash = "sha256:13cbe46c04be8e49da4f6aed698f2570a5295d15f498dd5511b4f761d1ef0edc", size = 250488, upload-time = "2026-03-12T15:49:42.649Z" }, - { url = "https://files.pythonhosted.org/packages/9d/3b/74a07e692a21cad4692e72595cdefbd709bd74a9f778c7334d57a98ee548/clickhouse_connect-0.14.1-cp311-cp311-win_amd64.whl", hash = "sha256:7038cf547c542a17a465e062cd837659f46f99c991efcb010a9ea08ce70960ab", size = 268730, upload-time = "2026-03-12T15:49:44.225Z" }, - { url = "https://files.pythonhosted.org/packages/58/9e/d84a14241967b3aa1e657bbbee83e2eee02d3d6df1ebe8edd4ed72cd8643/clickhouse_connect-0.14.1-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:97665169090889a8bc4dbae4a5fc758b91a23e49a8f8ddc1ae993f18f6d71e02", size = 280679, upload-time = "2026-03-12T15:49:45.497Z" }, - { url = "https://files.pythonhosted.org/packages/d8/29/80835a980be6298a7a2ae42d5a14aab0c9c066ecafe1763bc1958a6f6f0f/clickhouse_connect-0.14.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:3ee6b513ca7d83e0f7b46d87bc2e48260316431cb466680e3540400379bcd1db", size = 271570, upload-time = "2026-03-12T15:49:46.721Z" }, - { url = "https://files.pythonhosted.org/packages/8b/bf/25c17cb91d72143742d2b060c6954e8000a7753c1fd21f7bf8b49ef2bd89/clickhouse_connect-0.14.1-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:2a0e8a3f46aba99f1c574927d196e12f1ee689e31c41bf0caec86ad3e181abf3", size = 1115637, upload-time = "2026-03-12T15:49:47.921Z" }, - { 
url = "https://files.pythonhosted.org/packages/2d/5f/5d5df3585d98889aedc55c9eeb2ea90dba27ec4329eee392101619daf0c0/clickhouse_connect-0.14.1-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:25698cddcdd6c2e4ea12dc5c56d6035d77fc99c5d75e96a54123826c36fdd8ae", size = 1131995, upload-time = "2026-03-12T15:49:49.791Z" }, - { url = "https://files.pythonhosted.org/packages/ad/50/acc9f4c6a1d712f2ed11626f8451eff222e841cf0809655362f0e90454b6/clickhouse_connect-0.14.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:29ab49e5cac44b830b58de73d17a7d895f6c362bf67a50134ff405b428774f44", size = 1095380, upload-time = "2026-03-12T15:49:51.388Z" }, - { url = "https://files.pythonhosted.org/packages/08/18/1ef01beee93d243ec9d9c37f0ce62b3083478a5dd7f59cc13279600cd3a5/clickhouse_connect-0.14.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:3cbf7d7a134692bacd68dd5f8661e87f5db94af60db9f3a74bd732596794910a", size = 1127217, upload-time = "2026-03-12T15:49:53.016Z" }, - { url = "https://files.pythonhosted.org/packages/18/e2/b4daee8287dc49eb9918c77b1e57f5644e47008f719b77281bf5fca63f6e/clickhouse_connect-0.14.1-cp312-cp312-win32.whl", hash = "sha256:6f295b66f3e2ed931dd0d3bb80e00ee94c6f4a584b2dc6d998872b2e0ceaa706", size = 250775, upload-time = "2026-03-12T15:49:54.639Z" }, - { url = "https://files.pythonhosted.org/packages/01/c7/7b55d346952fcd8f0f491faca4449f607a04764fd23cada846dc93facb9e/clickhouse_connect-0.14.1-cp312-cp312-win_amd64.whl", hash = "sha256:c6bb2cce37041c90f8a3b1b380665acbaf252f125e401c13ce8f8df105378f69", size = 269353, upload-time = "2026-03-12T15:49:55.854Z" }, + { url = "https://files.pythonhosted.org/packages/83/b0/bf4a169a1b4e5e19f5e884596937ce13855146a3f4b3225228a87701fd18/clickhouse_connect-0.15.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4f0928fdfb408d314c0e5151caf30b1c3bd56c2812ffdbc8d262fb60c0e7ab28", size = 284805, upload-time = "2026-03-26T18:33:18.659Z" }, + { url = 
"https://files.pythonhosted.org/packages/ec/d5/63dd572db91bd5e1231d7b7dc63591c52ffbbf653a57f9b8449681815976/clickhouse_connect-0.15.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:6486b02825ac87f57811710e5a9a2da8531bb3c88bcb154fd5c7378742a33d66", size = 277846, upload-time = "2026-03-26T18:33:20.171Z" }, + { url = "https://files.pythonhosted.org/packages/e4/d6/192130a807de130945cc451e17c89ac6183625b8028026e5a4a7fc46fa59/clickhouse_connect-0.15.0-cp311-cp311-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:9f2df9c2fd97b40c6493232e0cbf516d8ba268165c6161851ef15f4f1fd0456e", size = 1096969, upload-time = "2026-03-26T18:33:21.728Z" }, + { url = "https://files.pythonhosted.org/packages/32/46/f2895cc4240ef45a2a274d4323f6858c0860034efe6c9a1c7168f1d8cecd/clickhouse_connect-0.15.0-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:a5a349d19c63abb49c884afe0a0387823045831f005451e85c09c032f953f1c1", size = 1101890, upload-time = "2026-03-26T18:33:23.038Z" }, + { url = "https://files.pythonhosted.org/packages/e8/69/dcecbca254b45525ad3fd8294441ac9cf8a8a8bd1fa8fd6b93e241b377a3/clickhouse_connect-0.15.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:4d80205cbdbface6d2f35fbd65a6f85caf2b59ec65f2e9dd190f11e335fe7316", size = 1083561, upload-time = "2026-03-26T18:33:24.64Z" }, + { url = "https://files.pythonhosted.org/packages/69/10/21f0cb98453d9710aaeb92f9a9e156e909c1ac72e57210a48b0f615916a7/clickhouse_connect-0.15.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:c3c84dfebf49ec7a2cd9ac31c46986f7a81b43ea781d23ef7d607907fcc6de5d", size = 1106257, upload-time = "2026-03-26T18:33:26.257Z" }, + { url = "https://files.pythonhosted.org/packages/70/91/ae0f5c8df5dc650f1ab327d4b40cde7e18bf9e8b3507764dce320c328092/clickhouse_connect-0.15.0-cp311-cp311-win32.whl", hash = "sha256:d2bbdccf9cd838b990576d3f7d1e6a0ab5c3a5c8eb830394258b7b225531fe74", size = 256591, upload-time = 
"2026-03-26T18:33:27.869Z" }, + { url = "https://files.pythonhosted.org/packages/e6/7f/85673ff522554ef76e17b5d267816c199a731fde836ef957b0960655f251/clickhouse_connect-0.15.0-cp311-cp311-win_amd64.whl", hash = "sha256:1c4223d557bc0a3919cb7ce0d749d9091123b6e61341e028ffc09b7f9c847ac2", size = 274778, upload-time = "2026-03-26T18:33:29.02Z" }, + { url = "https://files.pythonhosted.org/packages/f5/be/86e149c60822caed29e4435acac4fc73e20fddfb0b56ea6452bc7a08ab10/clickhouse_connect-0.15.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = "sha256:d51f49694e9007564bfd8dac51a1f9e60b94d6c93a07eb4027113a2e62bbb384", size = 286680, upload-time = "2026-03-26T18:33:30.219Z" }, + { url = "https://files.pythonhosted.org/packages/aa/65/c38cc5028afa2ccd9e8ff65611434063c0c5c1b6edadc507dbbc80a09bfd/clickhouse_connect-0.15.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:6a48fbad9ebc2b6d1cd01d1f9b5d6740081f1c84f1aacc9f91651be949f6b6ed", size = 277579, upload-time = "2026-03-26T18:33:31.474Z" }, + { url = "https://files.pythonhosted.org/packages/0a/ef/c8b2ef597fefd04e8b7c017c991552162cb89b7cb73bfdd6225b1c79e2fe/clickhouse_connect-0.15.0-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:36e1ae470b94cc56d270461c8626c8fd4dac16e6c1ffa8477f21c012462e22cf", size = 1121630, upload-time = "2026-03-26T18:33:32.983Z" }, + { url = "https://files.pythonhosted.org/packages/de/f7/1b71819e825d44582c014a489618170b03ccdac3c9b710dfd56445f1c017/clickhouse_connect-0.15.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:fa97f0ae8eb069a451d8577342dffeef5dc308a0eac7dba1809008c761e720c7", size = 1137988, upload-time = "2026-03-26T18:33:34.585Z" }, + { url = "https://files.pythonhosted.org/packages/7f/1f/41002b8d5ff146dc2835dc6b6f690bc361bd9a94b6195872abcb922f3788/clickhouse_connect-0.15.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:b5b3baf70009174a4df9c8356c96d03e1c2dbf0d8b29f1b3270a641a59399b61", size = 
1101376, upload-time = "2026-03-26T18:33:36.258Z" }, + { url = "https://files.pythonhosted.org/packages/2c/8a/bd090dab73fc9c47efcaaeb152a77610b9d233cd88ea73cf4535f9bac2a6/clickhouse_connect-0.15.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:af3fba93fd2efa8f856f3a88a6a710e06005fa48b6b6b0f116d462a4021957e2", size = 1133211, upload-time = "2026-03-26T18:33:38.003Z" }, + { url = "https://files.pythonhosted.org/packages/f1/8d/cf4eee7225bdee85a9b8a88c5bfff42ce48f37ee9277930ac8bc76f47126/clickhouse_connect-0.15.0-cp312-cp312-win32.whl", hash = "sha256:86ca76f8acaf7f3f6530e3e4139e174d54c4674910c69f4277d1b9cdf7c1cc98", size = 256767, upload-time = "2026-03-26T18:33:39.55Z" }, + { url = "https://files.pythonhosted.org/packages/26/6e/f5a2cb1e4624dfd77c1e226239360a9e3690db8056a0027bda2ab87d0085/clickhouse_connect-0.15.0-cp312-cp312-win_amd64.whl", hash = "sha256:5a471d9a9cf06f0a4e90784547b6a2acb066b0d8642dfea9866960c4bdde6959", size = 275404, upload-time = "2026-03-26T18:33:40.885Z" }, ] [[package]] @@ -1308,43 +1309,47 @@ wheels = [ [[package]] name = "cryptography" -version = "44.0.3" +version = "46.0.6" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "cffi", marker = "platform_python_implementation != 'PyPy'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/53/d6/1411ab4d6108ab167d06254c5be517681f1e331f90edf1379895bcb87020/cryptography-44.0.3.tar.gz", hash = "sha256:fe19d8bc5536a91a24a8133328880a41831b6c5df54599a8417b62fe015d3053", size = 711096, upload-time = "2025-05-02T19:36:04.667Z" } +sdist = { url = "https://files.pythonhosted.org/packages/a4/ba/04b1bd4218cbc58dc90ce967106d51582371b898690f3ae0402876cc4f34/cryptography-46.0.6.tar.gz", hash = "sha256:27550628a518c5c6c903d84f637fbecf287f6cb9ced3804838a1295dc1fd0759", size = 750542, upload-time = "2026-03-25T23:34:53.396Z" } wheels = [ - { url = 
"https://files.pythonhosted.org/packages/08/53/c776d80e9d26441bb3868457909b4e74dd9ccabd182e10b2b0ae7a07e265/cryptography-44.0.3-cp37-abi3-macosx_10_9_universal2.whl", hash = "sha256:962bc30480a08d133e631e8dfd4783ab71cc9e33d5d7c1e192f0b7c06397bb88", size = 6670281, upload-time = "2025-05-02T19:34:50.665Z" }, - { url = "https://files.pythonhosted.org/packages/6a/06/af2cf8d56ef87c77319e9086601bef621bedf40f6f59069e1b6d1ec498c5/cryptography-44.0.3-cp37-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4ffc61e8f3bf5b60346d89cd3d37231019c17a081208dfbbd6e1605ba03fa137", size = 3959305, upload-time = "2025-05-02T19:34:53.042Z" }, - { url = "https://files.pythonhosted.org/packages/ae/01/80de3bec64627207d030f47bf3536889efee8913cd363e78ca9a09b13c8e/cryptography-44.0.3-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:58968d331425a6f9eedcee087f77fd3c927c88f55368f43ff7e0a19891f2642c", size = 4171040, upload-time = "2025-05-02T19:34:54.675Z" }, - { url = "https://files.pythonhosted.org/packages/bd/48/bb16b7541d207a19d9ae8b541c70037a05e473ddc72ccb1386524d4f023c/cryptography-44.0.3-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:e28d62e59a4dbd1d22e747f57d4f00c459af22181f0b2f787ea83f5a876d7c76", size = 3963411, upload-time = "2025-05-02T19:34:56.61Z" }, - { url = "https://files.pythonhosted.org/packages/42/b2/7d31f2af5591d217d71d37d044ef5412945a8a8e98d5a2a8ae4fd9cd4489/cryptography-44.0.3-cp37-abi3-manylinux_2_28_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:af653022a0c25ef2e3ffb2c673a50e5a0d02fecc41608f4954176f1933b12359", size = 3689263, upload-time = "2025-05-02T19:34:58.591Z" }, - { url = "https://files.pythonhosted.org/packages/25/50/c0dfb9d87ae88ccc01aad8eb93e23cfbcea6a6a106a9b63a7b14c1f93c75/cryptography-44.0.3-cp37-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:157f1f3b8d941c2bd8f3ffee0af9b049c9665c39d3da9db2dc338feca5e98a43", size = 4196198, upload-time = "2025-05-02T19:35:00.988Z" }, - { url = 
"https://files.pythonhosted.org/packages/66/c9/55c6b8794a74da652690c898cb43906310a3e4e4f6ee0b5f8b3b3e70c441/cryptography-44.0.3-cp37-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:c6cd67722619e4d55fdb42ead64ed8843d64638e9c07f4011163e46bc512cf01", size = 3966502, upload-time = "2025-05-02T19:35:03.091Z" }, - { url = "https://files.pythonhosted.org/packages/b6/f7/7cb5488c682ca59a02a32ec5f975074084db4c983f849d47b7b67cc8697a/cryptography-44.0.3-cp37-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:b424563394c369a804ecbee9b06dfb34997f19d00b3518e39f83a5642618397d", size = 4196173, upload-time = "2025-05-02T19:35:05.018Z" }, - { url = "https://files.pythonhosted.org/packages/d2/0b/2f789a8403ae089b0b121f8f54f4a3e5228df756e2146efdf4a09a3d5083/cryptography-44.0.3-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:c91fc8e8fd78af553f98bc7f2a1d8db977334e4eea302a4bfd75b9461c2d8904", size = 4087713, upload-time = "2025-05-02T19:35:07.187Z" }, - { url = "https://files.pythonhosted.org/packages/1d/aa/330c13655f1af398fc154089295cf259252f0ba5df93b4bc9d9c7d7f843e/cryptography-44.0.3-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:25cd194c39fa5a0aa4169125ee27d1172097857b27109a45fadc59653ec06f44", size = 4299064, upload-time = "2025-05-02T19:35:08.879Z" }, - { url = "https://files.pythonhosted.org/packages/10/a8/8c540a421b44fd267a7d58a1fd5f072a552d72204a3f08194f98889de76d/cryptography-44.0.3-cp37-abi3-win32.whl", hash = "sha256:3be3f649d91cb182c3a6bd336de8b61a0a71965bd13d1a04a0e15b39c3d5809d", size = 2773887, upload-time = "2025-05-02T19:35:10.41Z" }, - { url = "https://files.pythonhosted.org/packages/b9/0d/c4b1657c39ead18d76bbd122da86bd95bdc4095413460d09544000a17d56/cryptography-44.0.3-cp37-abi3-win_amd64.whl", hash = "sha256:3883076d5c4cc56dbef0b898a74eb6992fdac29a7b9013870b34efe4ddb39a0d", size = 3209737, upload-time = "2025-05-02T19:35:12.12Z" }, - { url = 
"https://files.pythonhosted.org/packages/34/a3/ad08e0bcc34ad436013458d7528e83ac29910943cea42ad7dd4141a27bbb/cryptography-44.0.3-cp39-abi3-macosx_10_9_universal2.whl", hash = "sha256:5639c2b16764c6f76eedf722dbad9a0914960d3489c0cc38694ddf9464f1bb2f", size = 6673501, upload-time = "2025-05-02T19:35:13.775Z" }, - { url = "https://files.pythonhosted.org/packages/b1/f0/7491d44bba8d28b464a5bc8cc709f25a51e3eac54c0a4444cf2473a57c37/cryptography-44.0.3-cp39-abi3-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f3ffef566ac88f75967d7abd852ed5f182da252d23fac11b4766da3957766759", size = 3960307, upload-time = "2025-05-02T19:35:15.917Z" }, - { url = "https://files.pythonhosted.org/packages/f7/c8/e5c5d0e1364d3346a5747cdcd7ecbb23ca87e6dea4f942a44e88be349f06/cryptography-44.0.3-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:192ed30fac1728f7587c6f4613c29c584abdc565d7417c13904708db10206645", size = 4170876, upload-time = "2025-05-02T19:35:18.138Z" }, - { url = "https://files.pythonhosted.org/packages/73/96/025cb26fc351d8c7d3a1c44e20cf9a01e9f7cf740353c9c7a17072e4b264/cryptography-44.0.3-cp39-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:7d5fe7195c27c32a64955740b949070f21cba664604291c298518d2e255931d2", size = 3964127, upload-time = "2025-05-02T19:35:19.864Z" }, - { url = "https://files.pythonhosted.org/packages/01/44/eb6522db7d9f84e8833ba3bf63313f8e257729cf3a8917379473fcfd6601/cryptography-44.0.3-cp39-abi3-manylinux_2_28_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:3f07943aa4d7dad689e3bb1638ddc4944cc5e0921e3c227486daae0e31a05e54", size = 3689164, upload-time = "2025-05-02T19:35:21.449Z" }, - { url = "https://files.pythonhosted.org/packages/68/fb/d61a4defd0d6cee20b1b8a1ea8f5e25007e26aeb413ca53835f0cae2bcd1/cryptography-44.0.3-cp39-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:cb90f60e03d563ca2445099edf605c16ed1d5b15182d21831f58460c48bffb93", size = 4198081, upload-time = "2025-05-02T19:35:23.187Z" }, - { url = 
"https://files.pythonhosted.org/packages/1b/50/457f6911d36432a8811c3ab8bd5a6090e8d18ce655c22820994913dd06ea/cryptography-44.0.3-cp39-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:ab0b005721cc0039e885ac3503825661bd9810b15d4f374e473f8c89b7d5460c", size = 3967716, upload-time = "2025-05-02T19:35:25.426Z" }, - { url = "https://files.pythonhosted.org/packages/35/6e/dca39d553075980ccb631955c47b93d87d27f3596da8d48b1ae81463d915/cryptography-44.0.3-cp39-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:3bb0847e6363c037df8f6ede57d88eaf3410ca2267fb12275370a76f85786a6f", size = 4197398, upload-time = "2025-05-02T19:35:27.678Z" }, - { url = "https://files.pythonhosted.org/packages/9b/9d/d1f2fe681eabc682067c66a74addd46c887ebacf39038ba01f8860338d3d/cryptography-44.0.3-cp39-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:b0cc66c74c797e1db750aaa842ad5b8b78e14805a9b5d1348dc603612d3e3ff5", size = 4087900, upload-time = "2025-05-02T19:35:29.312Z" }, - { url = "https://files.pythonhosted.org/packages/c4/f5/3599e48c5464580b73b236aafb20973b953cd2e7b44c7c2533de1d888446/cryptography-44.0.3-cp39-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:6866df152b581f9429020320e5eb9794c8780e90f7ccb021940d7f50ee00ae0b", size = 4301067, upload-time = "2025-05-02T19:35:31.547Z" }, - { url = "https://files.pythonhosted.org/packages/a7/6c/d2c48c8137eb39d0c193274db5c04a75dab20d2f7c3f81a7dcc3a8897701/cryptography-44.0.3-cp39-abi3-win32.whl", hash = "sha256:c138abae3a12a94c75c10499f1cbae81294a6f983b3af066390adee73f433028", size = 2775467, upload-time = "2025-05-02T19:35:33.805Z" }, - { url = "https://files.pythonhosted.org/packages/c9/ad/51f212198681ea7b0deaaf8846ee10af99fba4e894f67b353524eab2bbe5/cryptography-44.0.3-cp39-abi3-win_amd64.whl", hash = "sha256:5d186f32e52e66994dce4f766884bcb9c68b8da62d61d9d215bfe5fb56d21334", size = 3210375, upload-time = "2025-05-02T19:35:35.369Z" }, - { url = 
"https://files.pythonhosted.org/packages/8d/4b/c11ad0b6c061902de5223892d680e89c06c7c4d606305eb8de56c5427ae6/cryptography-44.0.3-pp311-pypy311_pp73-macosx_10_9_x86_64.whl", hash = "sha256:896530bc9107b226f265effa7ef3f21270f18a2026bc09fed1ebd7b66ddf6375", size = 3390230, upload-time = "2025-05-02T19:35:49.062Z" }, - { url = "https://files.pythonhosted.org/packages/58/11/0a6bf45d53b9b2290ea3cec30e78b78e6ca29dc101e2e296872a0ffe1335/cryptography-44.0.3-pp311-pypy311_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:9b4d4a5dbee05a2c390bf212e78b99434efec37b17a4bff42f50285c5c8c9647", size = 3895216, upload-time = "2025-05-02T19:35:51.351Z" }, - { url = "https://files.pythonhosted.org/packages/0a/27/b28cdeb7270e957f0077a2c2bfad1b38f72f1f6d699679f97b816ca33642/cryptography-44.0.3-pp311-pypy311_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:02f55fb4f8b79c1221b0961488eaae21015b69b210e18c386b69de182ebb1259", size = 4115044, upload-time = "2025-05-02T19:35:53.044Z" }, - { url = "https://files.pythonhosted.org/packages/35/b0/ec4082d3793f03cb248881fecefc26015813199b88f33e3e990a43f79835/cryptography-44.0.3-pp311-pypy311_pp73-manylinux_2_34_aarch64.whl", hash = "sha256:dd3db61b8fe5be220eee484a17233287d0be6932d056cf5738225b9c05ef4fff", size = 3898034, upload-time = "2025-05-02T19:35:54.72Z" }, - { url = "https://files.pythonhosted.org/packages/0b/7f/adf62e0b8e8d04d50c9a91282a57628c00c54d4ae75e2b02a223bd1f2613/cryptography-44.0.3-pp311-pypy311_pp73-manylinux_2_34_x86_64.whl", hash = "sha256:978631ec51a6bbc0b7e58f23b68a8ce9e5f09721940933e9c217068388789fe5", size = 4114449, upload-time = "2025-05-02T19:35:57.139Z" }, - { url = "https://files.pythonhosted.org/packages/87/62/d69eb4a8ee231f4bf733a92caf9da13f1c81a44e874b1d4080c25ecbb723/cryptography-44.0.3-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:5d20cc348cca3a8aa7312f42ab953a56e15323800ca3ab0706b8cd452a3a056c", size = 3134369, upload-time = "2025-05-02T19:35:58.907Z" }, + { url = 
"https://files.pythonhosted.org/packages/47/23/9285e15e3bc57325b0a72e592921983a701efc1ee8f91c06c5f0235d86d9/cryptography-46.0.6-cp311-abi3-macosx_10_9_universal2.whl", hash = "sha256:64235194bad039a10bb6d2d930ab3323baaec67e2ce36215fd0952fad0930ca8", size = 7176401, upload-time = "2026-03-25T23:33:22.096Z" }, + { url = "https://files.pythonhosted.org/packages/60/f8/e61f8f13950ab6195b31913b42d39f0f9afc7d93f76710f299b5ec286ae6/cryptography-46.0.6-cp311-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:26031f1e5ca62fcb9d1fcb34b2b60b390d1aacaa15dc8b895a9ed00968b97b30", size = 4275275, upload-time = "2026-03-25T23:33:23.844Z" }, + { url = "https://files.pythonhosted.org/packages/19/69/732a736d12c2631e140be2348b4ad3d226302df63ef64d30dfdb8db7ad1c/cryptography-46.0.6-cp311-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:9a693028b9cbe51b5a1136232ee8f2bc242e4e19d456ded3fa7c86e43c713b4a", size = 4425320, upload-time = "2026-03-25T23:33:25.703Z" }, + { url = "https://files.pythonhosted.org/packages/d4/12/123be7292674abf76b21ac1fc0e1af50661f0e5b8f0ec8285faac18eb99e/cryptography-46.0.6-cp311-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:67177e8a9f421aa2d3a170c3e56eca4e0128883cf52a071a7cbf53297f18b175", size = 4278082, upload-time = "2026-03-25T23:33:27.423Z" }, + { url = "https://files.pythonhosted.org/packages/5b/ba/d5e27f8d68c24951b0a484924a84c7cdaed7502bac9f18601cd357f8b1d2/cryptography-46.0.6-cp311-abi3-manylinux_2_28_ppc64le.whl", hash = "sha256:d9528b535a6c4f8ff37847144b8986a9a143585f0540fbcb1a98115b543aa463", size = 4926514, upload-time = "2026-03-25T23:33:29.206Z" }, + { url = "https://files.pythonhosted.org/packages/34/71/1ea5a7352ae516d5512d17babe7e1b87d9db5150b21f794b1377eac1edc0/cryptography-46.0.6-cp311-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:22259338084d6ae497a19bae5d4c66b7ca1387d3264d1c2c0e72d9e9b6a77b97", size = 4457766, upload-time = "2026-03-25T23:33:30.834Z" }, + { url = 
"https://files.pythonhosted.org/packages/01/59/562be1e653accee4fdad92c7a2e88fced26b3fdfce144047519bbebc299e/cryptography-46.0.6-cp311-abi3-manylinux_2_31_armv7l.whl", hash = "sha256:760997a4b950ff00d418398ad73fbc91aa2894b5c1db7ccb45b4f68b42a63b3c", size = 3986535, upload-time = "2026-03-25T23:33:33.02Z" }, + { url = "https://files.pythonhosted.org/packages/d6/8b/b1ebfeb788bf4624d36e45ed2662b8bd43a05ff62157093c1539c1288a18/cryptography-46.0.6-cp311-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:3dfa6567f2e9e4c5dceb8ccb5a708158a2a871052fa75c8b78cb0977063f1507", size = 4277618, upload-time = "2026-03-25T23:33:34.567Z" }, + { url = "https://files.pythonhosted.org/packages/dd/52/a005f8eabdb28df57c20f84c44d397a755782d6ff6d455f05baa2785bd91/cryptography-46.0.6-cp311-abi3-manylinux_2_34_ppc64le.whl", hash = "sha256:cdcd3edcbc5d55757e5f5f3d330dd00007ae463a7e7aa5bf132d1f22a4b62b19", size = 4890802, upload-time = "2026-03-25T23:33:37.034Z" }, + { url = "https://files.pythonhosted.org/packages/ec/4d/8e7d7245c79c617d08724e2efa397737715ca0ec830ecb3c91e547302555/cryptography-46.0.6-cp311-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:d4e4aadb7fc1f88687f47ca20bb7227981b03afaae69287029da08096853b738", size = 4457425, upload-time = "2026-03-25T23:33:38.904Z" }, + { url = "https://files.pythonhosted.org/packages/1d/5c/f6c3596a1430cec6f949085f0e1a970638d76f81c3ea56d93d564d04c340/cryptography-46.0.6-cp311-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:2b417edbe8877cda9022dde3a008e2deb50be9c407eef034aeeb3a8b11d9db3c", size = 4405530, upload-time = "2026-03-25T23:33:40.842Z" }, + { url = "https://files.pythonhosted.org/packages/7e/c9/9f9cea13ee2dbde070424e0c4f621c091a91ffcc504ffea5e74f0e1daeff/cryptography-46.0.6-cp311-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:380343e0653b1c9d7e1f55b52aaa2dbb2fdf2730088d48c43ca1c7c0abb7cc2f", size = 4667896, upload-time = "2026-03-25T23:33:42.781Z" }, + { url = 
"https://files.pythonhosted.org/packages/ad/b5/1895bc0821226f129bc74d00eccfc6a5969e2028f8617c09790bf89c185e/cryptography-46.0.6-cp311-abi3-win32.whl", hash = "sha256:bcb87663e1f7b075e48c3be3ecb5f0b46c8fc50b50a97cf264e7f60242dca3f2", size = 3026348, upload-time = "2026-03-25T23:33:45.021Z" }, + { url = "https://files.pythonhosted.org/packages/c3/f8/c9bcbf0d3e6ad288b9d9aa0b1dee04b063d19e8c4f871855a03ab3a297ab/cryptography-46.0.6-cp311-abi3-win_amd64.whl", hash = "sha256:6739d56300662c468fddb0e5e291f9b4d084bead381667b9e654c7dd81705124", size = 3483896, upload-time = "2026-03-25T23:33:46.649Z" }, + { url = "https://files.pythonhosted.org/packages/c4/cc/f330e982852403da79008552de9906804568ae9230da8432f7496ce02b71/cryptography-46.0.6-cp38-abi3-macosx_10_9_universal2.whl", hash = "sha256:12cae594e9473bca1a7aceb90536060643128bb274fcea0fc459ab90f7d1ae7a", size = 7162776, upload-time = "2026-03-25T23:34:13.308Z" }, + { url = "https://files.pythonhosted.org/packages/49/b3/dc27efd8dcc4bff583b3f01d4a3943cd8b5821777a58b3a6a5f054d61b79/cryptography-46.0.6-cp38-abi3-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:639301950939d844a9e1c4464d7e07f902fe9a7f6b215bb0d4f28584729935d8", size = 4270529, upload-time = "2026-03-25T23:34:15.019Z" }, + { url = "https://files.pythonhosted.org/packages/e6/05/e8d0e6eb4f0d83365b3cb0e00eb3c484f7348db0266652ccd84632a3d58d/cryptography-46.0.6-cp38-abi3-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:ed3775295fb91f70b4027aeba878d79b3e55c0b3e97eaa4de71f8f23a9f2eb77", size = 4414827, upload-time = "2026-03-25T23:34:16.604Z" }, + { url = "https://files.pythonhosted.org/packages/2f/97/daba0f5d2dc6d855e2dcb70733c812558a7977a55dd4a6722756628c44d1/cryptography-46.0.6-cp38-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:8927ccfbe967c7df312ade694f987e7e9e22b2425976ddbf28271d7e58845290", size = 4271265, upload-time = "2026-03-25T23:34:18.586Z" }, + { url = 
"https://files.pythonhosted.org/packages/89/06/fe1fce39a37ac452e58d04b43b0855261dac320a2ebf8f5260dd55b201a9/cryptography-46.0.6-cp38-abi3-manylinux_2_28_ppc64le.whl", hash = "sha256:b12c6b1e1651e42ab5de8b1e00dc3b6354fdfd778e7fa60541ddacc27cd21410", size = 4916800, upload-time = "2026-03-25T23:34:20.561Z" }, + { url = "https://files.pythonhosted.org/packages/ff/8a/b14f3101fe9c3592603339eb5d94046c3ce5f7fc76d6512a2d40efd9724e/cryptography-46.0.6-cp38-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:063b67749f338ca9c5a0b7fe438a52c25f9526b851e24e6c9310e7195aad3b4d", size = 4448771, upload-time = "2026-03-25T23:34:22.406Z" }, + { url = "https://files.pythonhosted.org/packages/01/b3/0796998056a66d1973fd52ee89dc1bb3b6581960a91ad4ac705f182d398f/cryptography-46.0.6-cp38-abi3-manylinux_2_31_armv7l.whl", hash = "sha256:02fad249cb0e090b574e30b276a3da6a149e04ee2f049725b1f69e7b8351ec70", size = 3978333, upload-time = "2026-03-25T23:34:24.281Z" }, + { url = "https://files.pythonhosted.org/packages/c5/3d/db200af5a4ffd08918cd55c08399dc6c9c50b0bc72c00a3246e099d3a849/cryptography-46.0.6-cp38-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:7e6142674f2a9291463e5e150090b95a8519b2fb6e6aaec8917dd8d094ce750d", size = 4271069, upload-time = "2026-03-25T23:34:25.895Z" }, + { url = "https://files.pythonhosted.org/packages/d7/18/61acfd5b414309d74ee838be321c636fe71815436f53c9f0334bf19064fa/cryptography-46.0.6-cp38-abi3-manylinux_2_34_ppc64le.whl", hash = "sha256:456b3215172aeefb9284550b162801d62f5f264a081049a3e94307fe20792cfa", size = 4878358, upload-time = "2026-03-25T23:34:27.67Z" }, + { url = "https://files.pythonhosted.org/packages/8b/65/5bf43286d566f8171917cae23ac6add941654ccf085d739195a4eacf1674/cryptography-46.0.6-cp38-abi3-manylinux_2_34_x86_64.whl", hash = "sha256:341359d6c9e68834e204ceaf25936dffeafea3829ab80e9503860dcc4f4dac58", size = 4448061, upload-time = "2026-03-25T23:34:29.375Z" }, + { url = 
"https://files.pythonhosted.org/packages/e0/25/7e49c0fa7205cf3597e525d156a6bce5b5c9de1fd7e8cb01120e459f205a/cryptography-46.0.6-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:9a9c42a2723999a710445bc0d974e345c32adfd8d2fac6d8a251fa829ad31cfb", size = 4399103, upload-time = "2026-03-25T23:34:32.036Z" }, + { url = "https://files.pythonhosted.org/packages/44/46/466269e833f1c4718d6cd496ffe20c56c9c8d013486ff66b4f69c302a68d/cryptography-46.0.6-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:6617f67b1606dfd9fe4dbfa354a9508d4a6d37afe30306fe6c101b7ce3274b72", size = 4659255, upload-time = "2026-03-25T23:34:33.679Z" }, + { url = "https://files.pythonhosted.org/packages/0a/09/ddc5f630cc32287d2c953fc5d32705e63ec73e37308e5120955316f53827/cryptography-46.0.6-cp38-abi3-win32.whl", hash = "sha256:7f6690b6c55e9c5332c0b59b9c8a3fb232ebf059094c17f9019a51e9827df91c", size = 3010660, upload-time = "2026-03-25T23:34:35.418Z" }, + { url = "https://files.pythonhosted.org/packages/1b/82/ca4893968aeb2709aacfb57a30dec6fa2ab25b10fa9f064b8882ce33f599/cryptography-46.0.6-cp38-abi3-win_amd64.whl", hash = "sha256:79e865c642cfc5c0b3eb12af83c35c5aeff4fa5c672dc28c43721c2c9fdd2f0f", size = 3471160, upload-time = "2026-03-25T23:34:37.191Z" }, + { url = "https://files.pythonhosted.org/packages/2e/84/7ccff00ced5bac74b775ce0beb7d1be4e8637536b522b5df9b73ada42da2/cryptography-46.0.6-pp311-pypy311_pp73-macosx_11_0_arm64.whl", hash = "sha256:2ea0f37e9a9cf0df2952893ad145fd9627d326a59daec9b0802480fa3bcd2ead", size = 3475444, upload-time = "2026-03-25T23:34:38.944Z" }, + { url = "https://files.pythonhosted.org/packages/bc/1f/4c926f50df7749f000f20eede0c896769509895e2648db5da0ed55db711d/cryptography-46.0.6-pp311-pypy311_pp73-manylinux_2_28_aarch64.whl", hash = "sha256:a3e84d5ec9ba01f8fd03802b2147ba77f0c8f2617b2aff254cedd551844209c8", size = 4218227, upload-time = "2026-03-25T23:34:40.871Z" }, + { url = 
"https://files.pythonhosted.org/packages/c6/65/707be3ffbd5f786028665c3223e86e11c4cda86023adbc56bd72b1b6bab5/cryptography-46.0.6-pp311-pypy311_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:12f0fa16cc247b13c43d56d7b35287ff1569b5b1f4c5e87e92cc4fcc00cd10c0", size = 4381399, upload-time = "2026-03-25T23:34:42.609Z" }, + { url = "https://files.pythonhosted.org/packages/f3/6d/73557ed0ef7d73d04d9aba745d2c8e95218213687ee5e76b7d236a5030fc/cryptography-46.0.6-pp311-pypy311_pp73-manylinux_2_34_aarch64.whl", hash = "sha256:50575a76e2951fe7dbd1f56d181f8c5ceeeb075e9ff88e7ad997d2f42af06e7b", size = 4217595, upload-time = "2026-03-25T23:34:44.205Z" }, + { url = "https://files.pythonhosted.org/packages/9e/c5/e1594c4eec66a567c3ac4400008108a415808be2ce13dcb9a9045c92f1a0/cryptography-46.0.6-pp311-pypy311_pp73-manylinux_2_34_x86_64.whl", hash = "sha256:90e5f0a7b3be5f40c3a0a0eafb32c681d8d2c181fc2a1bdabe9b3f611d9f6b1a", size = 4380912, upload-time = "2026-03-25T23:34:46.328Z" }, + { url = "https://files.pythonhosted.org/packages/1a/89/843b53614b47f97fe1abc13f9a86efa5ec9e275292c457af1d4a60dc80e0/cryptography-46.0.6-pp311-pypy311_pp73-win_amd64.whl", hash = "sha256:6728c49e3b2c180ef26f8e9f0a883a2c585638db64cf265b49c9ba10652d430e", size = 3409955, upload-time = "2026-03-25T23:34:48.465Z" }, ] [[package]] @@ -1457,7 +1462,7 @@ wheels = [ [[package]] name = "dify-api" -version = "1.13.2" +version = "1.13.3" source = { virtual = "." 
} dependencies = [ { name = "aliyun-log-python-sdk" }, @@ -1489,12 +1494,12 @@ dependencies = [ { name = "google-auth-httplib2" }, { name = "google-cloud-aiplatform" }, { name = "googleapis-common-protos" }, + { name = "graphon" }, { name = "gunicorn" }, { name = "httpx", extra = ["socks"] }, { name = "httpx-sse" }, { name = "jieba" }, { name = "json-repair" }, - { name = "jsonschema" }, { name = "langfuse" }, { name = "langsmith" }, { name = "litellm" }, @@ -1526,7 +1531,6 @@ dependencies = [ { name = "psycopg2-binary" }, { name = "pycryptodome" }, { name = "pydantic" }, - { name = "pydantic-extra-types" }, { name = "pydantic-settings" }, { name = "pyjwt" }, { name = "pypandoc" }, @@ -1547,7 +1551,6 @@ dependencies = [ { name = "unstructured", extra = ["docx", "epub", "md", "ppt", "pptx"] }, { name = "weave" }, { name = "weaviate-client" }, - { name = "webvtt-py" }, { name = "yarl" }, ] @@ -1590,7 +1593,6 @@ dev = [ { name = "types-greenlet" }, { name = "types-html5lib" }, { name = "types-jmespath" }, - { name = "types-jsonschema" }, { name = "types-markdown" }, { name = "types-oauthlib" }, { name = "types-objgraph" }, @@ -1669,7 +1671,7 @@ requires-dist = [ { name = "azure-identity", specifier = "==1.25.3" }, { name = "beautifulsoup4", specifier = "==4.14.3" }, { name = "bleach", specifier = "~=6.3.0" }, - { name = "boto3", specifier = "==1.42.73" }, + { name = "boto3", specifier = "==1.42.78" }, { name = "bs4", specifier = "~=0.0.1" }, { name = "cachetools", specifier = "~=5.3.0" }, { name = "celery", specifier = "~=5.6.2" }, @@ -1692,12 +1694,12 @@ requires-dist = [ { name = "google-auth-httplib2", specifier = "==0.3.0" }, { name = "google-cloud-aiplatform", specifier = ">=1.123.0" }, { name = "googleapis-common-protos", specifier = ">=1.65.0" }, - { name = "gunicorn", specifier = "~=25.1.0" }, + { name = "graphon", specifier = ">=0.1.2" }, + { name = "gunicorn", specifier = "~=25.3.0" }, { name = "httpx", extras = ["socks"], specifier = "~=0.28.0" }, { name = 
"httpx-sse", specifier = "~=0.4.0" }, { name = "jieba", specifier = "==0.42.1" }, { name = "json-repair", specifier = ">=0.55.1" }, - { name = "jsonschema", specifier = ">=4.25.1" }, { name = "langfuse", specifier = "~=2.51.3" }, { name = "langsmith", specifier = "~=0.7.16" }, { name = "litellm", specifier = "==1.82.6" }, @@ -1729,7 +1731,6 @@ requires-dist = [ { name = "psycopg2-binary", specifier = "~=2.9.6" }, { name = "pycryptodome", specifier = "==3.23.0" }, { name = "pydantic", specifier = "~=2.12.5" }, - { name = "pydantic-extra-types", specifier = "~=2.11.0" }, { name = "pydantic-settings", specifier = "~=2.13.1" }, { name = "pyjwt", specifier = "~=2.12.0" }, { name = "pypandoc", specifier = "~=1.13" }, @@ -1738,7 +1739,7 @@ requires-dist = [ { name = "python-dotenv", specifier = "==1.2.2" }, { name = "pyyaml", specifier = "~=6.0.1" }, { name = "readabilipy", specifier = "~=0.3.0" }, - { name = "redis", extras = ["hiredis"], specifier = "~=7.3.0" }, + { name = "redis", extras = ["hiredis"], specifier = "~=7.4.0" }, { name = "resend", specifier = "~=2.26.0" }, { name = "sendgrid", specifier = "~=6.12.3" }, { name = "sentry-sdk", extras = ["flask"], specifier = "~=2.55.0" }, @@ -1750,7 +1751,6 @@ requires-dist = [ { name = "unstructured", extras = ["docx", "epub", "md", "ppt", "pptx"], specifier = "~=0.21.5" }, { name = "weave", specifier = ">=0.52.16" }, { name = "weaviate-client", specifier = "==4.20.4" }, - { name = "webvtt-py", specifier = "~=0.5.1" }, { name = "yarl", specifier = "~=1.23.0" }, ] @@ -1793,7 +1793,6 @@ dev = [ { name = "types-greenlet", specifier = "~=3.3.0" }, { name = "types-html5lib", specifier = "~=1.1.11" }, { name = "types-jmespath", specifier = ">=1.0.2.20240106" }, - { name = "types-jsonschema", specifier = "~=4.26.0" }, { name = "types-markdown", specifier = "~=3.10.2" }, { name = "types-oauthlib", specifier = "~=3.3.0" }, { name = "types-objgraph", specifier = "~=3.6.0" }, @@ -1811,7 +1810,7 @@ dev = [ { name = "types-pywin32", 
specifier = "~=311.0.0" }, { name = "types-pyyaml", specifier = "~=6.0.12" }, { name = "types-redis", specifier = ">=4.6.0.20241004" }, - { name = "types-regex", specifier = "~=2026.2.28" }, + { name = "types-regex", specifier = "~=2026.3.32" }, { name = "types-setuptools", specifier = ">=80.9.0" }, { name = "types-shapely", specifier = "~=2.1.0" }, { name = "types-simplejson", specifier = ">=3.20.0" }, @@ -1839,7 +1838,7 @@ vdb = [ { name = "alibabacloud-gpdb20160503", specifier = "~=5.1.0" }, { name = "alibabacloud-tea-openapi", specifier = "~=0.4.3" }, { name = "chromadb", specifier = "==0.5.20" }, - { name = "clickhouse-connect", specifier = "~=0.14.1" }, + { name = "clickhouse-connect", specifier = "~=0.15.0" }, { name = "clickzetta-connector-python", specifier = ">=0.8.102" }, { name = "couchbase", specifier = "~=4.5.0" }, { name = "elasticsearch", specifier = "==8.14.0" }, @@ -1855,13 +1854,13 @@ vdb = [ { name = "pymochow", specifier = "==2.3.6" }, { name = "pyobvector", specifier = "~=0.2.17" }, { name = "qdrant-client", specifier = "==1.9.0" }, - { name = "tablestore", specifier = "==6.4.1" }, - { name = "tcvectordb", specifier = "~=2.0.0" }, + { name = "tablestore", specifier = "==6.4.2" }, + { name = "tcvectordb", specifier = "~=2.1.0" }, { name = "tidb-vector", specifier = "==0.0.15" }, { name = "upstash-vector", specifier = "==0.8.0" }, { name = "volcengine-compat", specifier = "~=1.0.0" }, { name = "weaviate-client", specifier = "==4.20.4" }, - { name = "xinference-client", specifier = "~=2.3.1" }, + { name = "xinference-client", specifier = "~=2.4.0" }, ] [[package]] @@ -2024,14 +2023,14 @@ wheels = [ [[package]] name = "faker" -version = "40.11.0" +version = "40.11.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "tzdata", marker = "sys_platform == 'win32'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/94/dc/b68e5378e5a7db0ab776efcdd53b6fe374b29d703e156fd5bb4c5437069e/faker-40.11.0.tar.gz", hash = 
"sha256:7c419299103b13126bd02ec14bd2b47b946edb5a5eedf305e66a193b25f9a734", size = 1957570, upload-time = "2026-03-13T14:36:11.844Z" } +sdist = { url = "https://files.pythonhosted.org/packages/fa/e5/b16bf568a2f20fe7423282db4a4059dbcadef70e9029c1c106836f8edd84/faker-40.11.1.tar.gz", hash = "sha256:61965046e79e8cfde4337d243eac04c0d31481a7c010033141103b43f603100c", size = 1957415, upload-time = "2026-03-23T14:05:50.233Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/b1/fa/a86c6ba66f0308c95b9288b1e3eaccd934b545646f63494a86f1ec2f8c8e/faker-40.11.0-py3-none-any.whl", hash = "sha256:0e9816c950528d2a37d74863f3ef389ea9a3a936cbcde0b11b8499942e25bf90", size = 1989457, upload-time = "2026-03-13T14:36:09.792Z" }, + { url = "https://files.pythonhosted.org/packages/fc/ec/3c4b78eb0d2f6a81fb8cc9286745845bff661e6815741eff7a6ac5fcc9ea/faker-40.11.1-py3-none-any.whl", hash = "sha256:3af3a213ba8fb33ce6ba2af7aef2ac91363dae35d0cec0b2b0337d189e5bee2a", size = 1989484, upload-time = "2026-03-23T14:05:48.793Z" }, ] [[package]] @@ -2472,7 +2471,7 @@ wheels = [ [[package]] name = "google-cloud-aiplatform" -version = "1.142.0" +version = "1.143.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "docstring-parser" }, @@ -2488,9 +2487,9 @@ dependencies = [ { name = "pydantic" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/41/0d/3063a0512d60cf18854a279e00ccb796429545464345ef821cf77cb93d05/google_cloud_aiplatform-1.142.0.tar.gz", hash = "sha256:87b49e002703dc14885093e9b264587db84222bef5f70f5a442d03f41beecdd1", size = 10207993, upload-time = "2026-03-20T22:49:13.797Z" } +sdist = { url = "https://files.pythonhosted.org/packages/a7/08/939fb05870fdf155410a927e22f5b053d49f18e215618e102fba1d8bb147/google_cloud_aiplatform-1.143.0.tar.gz", hash = "sha256:1f0124a89795a6b473deb28724dd37d95334205df3a9c9c48d0b8d7a3d5d5cc4", size = 10215389, upload-time = "2026-03-25T18:30:15.444Z" } wheels = [ - { url = 
"https://files.pythonhosted.org/packages/59/8b/f29646d3fa940f0e38cfcc12137f4851856b50d7486a3c05103ebc78d82d/google_cloud_aiplatform-1.142.0-py2.py3-none-any.whl", hash = "sha256:17c91db9b613cbbafb2c36335b123686aeb2b4b8448be5134b565ae07165a39a", size = 8388991, upload-time = "2026-03-20T22:49:10.334Z" }, + { url = "https://files.pythonhosted.org/packages/90/14/16323e604e79dc63b528268f97a841c2c29dd8eb16395de6bf530c1a5ebe/google_cloud_aiplatform-1.143.0-py2.py3-none-any.whl", hash = "sha256:78df97d044859f743a9cc48b89a260d33579b0d548b1589bb3ae9f4c2afc0c5a", size = 8392705, upload-time = "2026-03-25T18:30:11.496Z" }, ] [[package]] @@ -2543,7 +2542,7 @@ wheels = [ [[package]] name = "google-cloud-storage" -version = "3.10.0" +version = "3.10.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "google-api-core" }, @@ -2553,9 +2552,9 @@ dependencies = [ { name = "google-resumable-media" }, { name = "requests" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/7a/e3/747759eebc72e420c25903d6bc231d0ceb110b66ac7e6ee3f350417152cd/google_cloud_storage-3.10.0.tar.gz", hash = "sha256:1aeebf097c27d718d84077059a28d7e87f136f3700212215f1ceeae1d1c5d504", size = 17309829, upload-time = "2026-03-18T15:54:11.875Z" } +sdist = { url = "https://files.pythonhosted.org/packages/4c/47/205eb8e9a1739b5345843e5a425775cbdc472cc38e7eda082ba5b8d02450/google_cloud_storage-3.10.1.tar.gz", hash = "sha256:97db9aa4460727982040edd2bd13ff3d5e2260b5331ad22895802da1fc2a5286", size = 17309950, upload-time = "2026-03-23T09:35:23.409Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/29/e2/d58442f4daee5babd9255cf492a1f3d114357164072f8339a22a3ad460a2/google_cloud_storage-3.10.0-py3-none-any.whl", hash = "sha256:0072e7783b201e45af78fd9779894cdb6bec2bf922ee932f3fcc16f8bce9b9a3", size = 324382, upload-time = "2026-03-18T15:54:10.091Z" }, + { url = 
"https://files.pythonhosted.org/packages/ad/ff/ca9ab2417fa913d75aae38bf40bf856bb2749a604b2e0f701b37cfcd23cc/google_cloud_storage-3.10.1-py3-none-any.whl", hash = "sha256:a72f656759b7b99bda700f901adcb3425a828d4a29f911bc26b3ea79c5b1217f", size = 324453, upload-time = "2026-03-23T09:35:21.368Z" }, ] [[package]] @@ -2613,14 +2612,14 @@ wheels = [ [[package]] name = "googleapis-common-protos" -version = "1.73.0" +version = "1.73.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "protobuf" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/99/96/a0205167fa0154f4a542fd6925bdc63d039d88dab3588b875078107e6f06/googleapis_common_protos-1.73.0.tar.gz", hash = "sha256:778d07cd4fbeff84c6f7c72102f0daf98fa2bfd3fa8bea426edc545588da0b5a", size = 147323, upload-time = "2026-03-06T21:53:09.727Z" } +sdist = { url = "https://files.pythonhosted.org/packages/a1/c0/4a54c386282c13449eca8bbe2ddb518181dc113e78d240458a68856b4d69/googleapis_common_protos-1.73.1.tar.gz", hash = "sha256:13114f0e9d2391756a0194c3a8131974ed7bffb06086569ba193364af59163b6", size = 147506, upload-time = "2026-03-26T22:17:38.451Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/69/28/23eea8acd65972bbfe295ce3666b28ac510dfcb115fac089d3edb0feb00a/googleapis_common_protos-1.73.0-py3-none-any.whl", hash = "sha256:dfdaaa2e860f242046be561e6d6cb5c5f1541ae02cfbcb034371aadb2942b4e8", size = 297578, upload-time = "2026-03-06T21:52:33.933Z" }, + { url = "https://files.pythonhosted.org/packages/dc/82/fcb6520612bec0c39b973a6c0954b6a0d948aadfe8f7e9487f60ceb8bfa6/googleapis_common_protos-1.73.1-py3-none-any.whl", hash = "sha256:e51f09eb0a43a8602f5a915870972e6b4a394088415c79d79605a46d8e826ee8", size = 297556, upload-time = "2026-03-26T22:15:58.455Z" }, ] [package.optional-dependencies] @@ -2652,6 +2651,34 @@ requests = [ { name = "requests-toolbelt" }, ] +[[package]] +name = "graphon" +version = "0.1.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = 
"charset-normalizer" }, + { name = "httpx" }, + { name = "json-repair" }, + { name = "jsonschema" }, + { name = "orjson" }, + { name = "pandas", extra = ["excel"] }, + { name = "pydantic" }, + { name = "pydantic-extra-types" }, + { name = "pypandoc" }, + { name = "pypdfium2" }, + { name = "python-docx" }, + { name = "pyyaml" }, + { name = "tiktoken" }, + { name = "transformers" }, + { name = "typing-extensions" }, + { name = "unstructured", extra = ["docx", "epub", "md", "ppt", "pptx"] }, + { name = "webvtt-py" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/33/fc/0a5342a1c29bc367c2254c170ef130a84a60d8cd1c9cc84a7a85e96c1042/graphon-0.1.2.tar.gz", hash = "sha256:a2210629f93258ad2e7cbe85b5d4c6826814f6c679aa2a23ca100511363b9240", size = 214744, upload-time = "2026-03-27T20:09:53.802Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d5/46/65b5e366ec2d7017b6d6448e2635b3772d86840a6f7297277471b1bfbfbd/graphon-0.1.2-py3-none-any.whl", hash = "sha256:79f0c7796de7b8642d070730bb8bdaf1c68ccdfcecac38e0b2282e0543f0a6db", size = 314398, upload-time = "2026-03-27T20:09:52.524Z" }, +] + [[package]] name = "graphql-core" version = "3.2.7" @@ -2843,14 +2870,14 @@ wheels = [ [[package]] name = "gunicorn" -version = "25.1.0" +version = "25.3.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "packaging" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/66/13/ef67f59f6a7896fdc2c1d62b5665c5219d6b0a9a1784938eb9a28e55e128/gunicorn-25.1.0.tar.gz", hash = "sha256:1426611d959fa77e7de89f8c0f32eed6aa03ee735f98c01efba3e281b1c47616", size = 594377, upload-time = "2026-02-13T11:09:58.989Z" } +sdist = { url = "https://files.pythonhosted.org/packages/c4/f4/e78fa054248fab913e2eab0332c6c2cb07421fca1ce56d8fe43b6aef57a4/gunicorn-25.3.0.tar.gz", hash = "sha256:f74e1b2f9f76f6cd1ca01198968bd2dd65830edc24b6e8e4d78de8320e2fe889", size = 634883, upload-time = "2026-03-27T00:00:26.092Z" } wheels = [ - { url = 
"https://files.pythonhosted.org/packages/da/73/4ad5b1f6a2e21cf1e85afdaad2b7b1a933985e2f5d679147a1953aaa192c/gunicorn-25.1.0-py3-none-any.whl", hash = "sha256:d0b1236ccf27f72cfe14bce7caadf467186f19e865094ca84221424e839b8b8b", size = 197067, upload-time = "2026-02-13T11:09:57.146Z" }, + { url = "https://files.pythonhosted.org/packages/43/c8/8aaf447698c4d59aa853fd318eed300b5c9e44459f242ab8ead6c9c09792/gunicorn-25.3.0-py3-none-any.whl", hash = "sha256:cacea387dab08cd6776501621c295a904fe8e3b7aae9a1a3cbb26f4e7ed54660", size = 208403, upload-time = "2026-03-27T00:00:27.386Z" }, ] [[package]] @@ -3083,14 +3110,14 @@ wheels = [ [[package]] name = "hypothesis" -version = "6.151.9" +version = "6.151.10" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "sortedcontainers" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/19/e1/ef365ff480903b929d28e057f57b76cae51a30375943e33374ec9a165d9c/hypothesis-6.151.9.tar.gz", hash = "sha256:2f284428dda6c3c48c580de0e18470ff9c7f5ef628a647ee8002f38c3f9097ca", size = 463534, upload-time = "2026-02-16T22:59:23.09Z" } +sdist = { url = "https://files.pythonhosted.org/packages/f5/dd/633e2cd62377333b7681628aee2ec1d88166f5bdf916b08c98b1e8288ad3/hypothesis-6.151.10.tar.gz", hash = "sha256:6c9565af8b4aa3a080b508f66ce9c2a77dd613c7e9073e27fc7e4ef9f45f8a27", size = 463762, upload-time = "2026-03-29T01:06:22.19Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/c4/f7/5cc291d701094754a1d327b44d80a44971e13962881d9a400235726171da/hypothesis-6.151.9-py3-none-any.whl", hash = "sha256:7b7220585c67759b1b1ef839b1e6e9e3d82ed468cfc1ece43c67184848d7edd9", size = 529307, upload-time = "2026-02-16T22:59:20.443Z" }, + { url = "https://files.pythonhosted.org/packages/40/da/439bb2e451979f5e88c13bbebc3e9e17754429cfb528c93677b2bd81783b/hypothesis-6.151.10-py3-none-any.whl", hash = "sha256:b0d7728f0c8c2be009f89fcdd6066f70c5439aa0f94adbb06e98261d05f49b05", size = 529493, upload-time = "2026-03-29T01:06:19.161Z" }, ] 
[[package]] @@ -3905,7 +3932,7 @@ wheels = [ [[package]] name = "nltk" -version = "3.9.3" +version = "3.9.4" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "click" }, @@ -3913,9 +3940,9 @@ dependencies = [ { name = "regex" }, { name = "tqdm" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/e1/8f/915e1c12df07c70ed779d18ab83d065718a926e70d3ea33eb0cd66ffb7c0/nltk-3.9.3.tar.gz", hash = "sha256:cb5945d6424a98d694c2b9a0264519fab4363711065a46aa0ae7a2195b92e71f", size = 2923673, upload-time = "2026-02-24T12:05:53.833Z" } +sdist = { url = "https://files.pythonhosted.org/packages/74/a1/b3b4adf15585a5bc4c357adde150c01ebeeb642173ded4d871e89468767c/nltk-3.9.4.tar.gz", hash = "sha256:ed03bc098a40481310320808b2db712d95d13ca65b27372f8a403949c8b523d0", size = 2946864, upload-time = "2026-03-24T06:13:40.641Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/c2/7e/9af5a710a1236e4772de8dfcc6af942a561327bb9f42b5b4a24d0cf100fd/nltk-3.9.3-py3-none-any.whl", hash = "sha256:60b3db6e9995b3dd976b1f0fa7dec22069b2677e759c28eb69b62ddd44870522", size = 1525385, upload-time = "2026-02-24T12:05:46.54Z" }, + { url = "https://files.pythonhosted.org/packages/9d/91/04e965f8e717ba0ab4bdca5c112deeab11c9e750d94c4d4602f050295d39/nltk-3.9.4-py3-none-any.whl", hash = "sha256:f2fa301c3a12718ce4a0e9305c5675299da5ad9e26068218b69d692fda84828f", size = 1552087, upload-time = "2026-03-24T06:13:38.47Z" }, ] [[package]] @@ -4462,7 +4489,7 @@ wheels = [ [[package]] name = "opik" -version = "1.10.45" +version = "1.10.54" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "boto3-stubs", extra = ["bedrock-runtime"] }, @@ -4481,9 +4508,9 @@ dependencies = [ { name = "tqdm" }, { name = "uuid6" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/85/17/edea6308347cec62e6828de7c573c596559c502b54fa4f0c88a52e2e81f5/opik-1.10.45.tar.gz", hash = "sha256:d8d8627ba03d12def46965e03d58f611daaf5cf878b3d087c53fe1159788c140", size = 789876, 
upload-time = "2026-03-20T11:35:12.457Z" } +sdist = { url = "https://files.pythonhosted.org/packages/fa/c9/ecc68c5ae32bf5b1074bdc713cb1543b8e2a46c58c814bf150fecf50f272/opik-1.10.54.tar.gz", hash = "sha256:46e29abf4656bd80b9cb339659d24ecf97b61f37c3fde594de75e5f59953e9d3", size = 812757, upload-time = "2026-03-27T11:23:06.109Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/b7/17/150e9eecfa28cb23f7a0bfe83ae1486a11022b97fe6d12328b455784658d/opik-1.10.45-py3-none-any.whl", hash = "sha256:e8050d9e5e0d92ff587f156eacbdd02099897f39cfe79a98380b6c8ae9906b95", size = 1337714, upload-time = "2026-03-20T11:35:10.237Z" }, + { url = "https://files.pythonhosted.org/packages/58/91/1ae4e8a349da0620a6f0a4fc51cd00c3e75176939d022e8684379aee2928/opik-1.10.54-py3-none-any.whl", hash = "sha256:5f8ddabe5283ebe08d455e81b188d6e09ce1d1efa989f8b05567ef70f1e9aeda", size = 1379008, upload-time = "2026-03-27T11:23:04.582Z" }, ] [[package]] @@ -5249,7 +5276,7 @@ crypto = [ [[package]] name = "pymilvus" -version = "2.6.10" +version = "2.6.11" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "cachetools" }, @@ -5261,9 +5288,9 @@ dependencies = [ { name = "requests" }, { name = "setuptools" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/9e/85/90362066ccda5ff6fec693a55693cde659fdcd36d08f1bd7012ae958248d/pymilvus-2.6.10.tar.gz", hash = "sha256:58a44ee0f1dddd7727ae830ef25325872d8946f029d801a37105164e6699f1b8", size = 1561042, upload-time = "2026-03-13T09:54:22.441Z" } +sdist = { url = "https://files.pythonhosted.org/packages/c9/e6/0adc3b374f5c5d1eebd4f551b455c6865c449b170b17545001b208e2b153/pymilvus-2.6.11.tar.gz", hash = "sha256:a40c10322cde25184a8c3d84993a14dfb67ad2bdcfc5dff7e68b11a79ff8f6d8", size = 1583634, upload-time = "2026-03-27T06:25:46.023Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/88/10/fe7fbb6795aa20038afd55e9c653991e7c69fb24c741ebb39ba3b0aa5c13/pymilvus-2.6.10-py3-none-any.whl", hash = 
"sha256:a048b6f3ebad93742bca559beabf44fe578f0983555a109c4436b5fb2c1dbd40", size = 312797, upload-time = "2026-03-13T09:54:21.081Z" }, + { url = "https://files.pythonhosted.org/packages/9c/1c/bccb331d71f824738f80f11e9b8b4da47973c903826355526ae4fa2b762f/pymilvus-2.6.11-py3-none-any.whl", hash = "sha256:a11e1718b15045361c71ca671b959900cb7e2faae863c896f6b7e87bf2e4d10a", size = 315252, upload-time = "2026-03-27T06:25:44.215Z" }, ] [[package]] @@ -5340,11 +5367,11 @@ wheels = [ [[package]] name = "pypdf" -version = "6.9.1" +version = "6.9.2" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/f9/fb/dc2e8cb006e80b0020ed20d8649106fe4274e82d8e756ad3e24ade19c0df/pypdf-6.9.1.tar.gz", hash = "sha256:ae052407d33d34de0c86c5c729be6d51010bf36e03035a8f23ab449bca52377d", size = 5311551, upload-time = "2026-03-17T10:46:07.876Z" } +sdist = { url = "https://files.pythonhosted.org/packages/31/83/691bdb309306232362503083cb15777491045dd54f45393a317dc7d8082f/pypdf-6.9.2.tar.gz", hash = "sha256:7f850faf2b0d4ab936582c05da32c52214c2b089d61a316627b5bfb5b0dab46c", size = 5311837, upload-time = "2026-03-23T14:53:27.983Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/f9/f4/75543fa802b86e72f87e9395440fe1a89a6d149887e3e55745715c3352ac/pypdf-6.9.1-py3-none-any.whl", hash = "sha256:f35a6a022348fae47e092a908339a8f3dc993510c026bb39a96718fc7185e89f", size = 333661, upload-time = "2026-03-17T10:46:06.286Z" }, + { url = "https://files.pythonhosted.org/packages/a5/7e/c85f41243086a8fe5d1baeba527cb26a1918158a565932b41e0f7c0b32e9/pypdf-6.9.2-py3-none-any.whl", hash = "sha256:662cf29bcb419a36a1365232449624ab40b7c2d0cfc28e54f42eeecd1fd7e844", size = 333744, upload-time = "2026-03-23T14:53:26.573Z" }, ] [[package]] @@ -5785,14 +5812,14 @@ wheels = [ [[package]] name = "redis" -version = "7.3.0" +version = "7.4.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "async-timeout", marker = "python_full_version < 
'3.11.3'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/da/82/4d1a5279f6c1251d3d2a603a798a1137c657de9b12cfc1fba4858232c4d2/redis-7.3.0.tar.gz", hash = "sha256:4d1b768aafcf41b01022410b3cc4f15a07d9b3d6fe0c66fc967da2c88e551034", size = 4928081, upload-time = "2026-03-06T18:18:16.287Z" } +sdist = { url = "https://files.pythonhosted.org/packages/7b/7f/3759b1d0d72b7c92f0d70ffd9dc962b7b7b5ee74e135f9d7d8ab06b8a318/redis-7.4.0.tar.gz", hash = "sha256:64a6ea7bf567ad43c964d2c30d82853f8df927c5c9017766c55a1d1ed95d18ad", size = 4943913, upload-time = "2026-03-24T09:14:37.53Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/f0/28/84e57fce7819e81ec5aa1bd31c42b89607241f4fb1a3ea5b0d2dbeaea26c/redis-7.3.0-py3-none-any.whl", hash = "sha256:9d4fcb002a12a5e3c3fbe005d59c48a2cc231f87fbb2f6b70c2d89bb64fec364", size = 404379, upload-time = "2026-03-06T18:18:14.583Z" }, + { url = "https://files.pythonhosted.org/packages/74/3a/95deec7db1eb53979973ebd156f3369a72732208d1391cd2e5d127062a32/redis-7.4.0-py3-none-any.whl", hash = "sha256:a9c74a5c893a5ef8455a5adb793a31bb70feb821c86eccb62eebef5a19c429ec", size = 409772, upload-time = "2026-03-24T09:14:35.968Z" }, ] [package.optional-dependencies] @@ -5852,7 +5879,7 @@ wheels = [ [[package]] name = "requests" -version = "2.32.5" +version = "2.33.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "certifi" }, @@ -5860,9 +5887,9 @@ dependencies = [ { name = "idna" }, { name = "urllib3" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/c9/74/b3ff8e6c8446842c3f5c837e9c3dfcfe2018ea6ecef224c710c85ef728f4/requests-2.32.5.tar.gz", hash = "sha256:dbba0bac56e100853db0ea71b82b4dfd5fe2bf6d3754a8893c3af500cec7d7cf", size = 134517, upload-time = "2025-08-18T20:46:02.573Z" } +sdist = { url = "https://files.pythonhosted.org/packages/34/64/8860370b167a9721e8956ae116825caff829224fbca0ca6e7bf8ddef8430/requests-2.33.0.tar.gz", hash = 
"sha256:c7ebc5e8b0f21837386ad0e1c8fe8b829fa5f544d8df3b2253bff14ef29d7652", size = 134232, upload-time = "2026-03-25T15:10:41.586Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/1e/db/4254e3eabe8020b458f1a747140d32277ec7a271daf1d235b70dc0b4e6e3/requests-2.32.5-py3-none-any.whl", hash = "sha256:2462f94637a34fd532264295e186976db0f5d453d1cdd31473c85a6a161affb6", size = 64738, upload-time = "2025-08-18T20:46:00.542Z" }, + { url = "https://files.pythonhosted.org/packages/56/5d/c814546c2333ceea4ba42262d8c4d55763003e767fa169adc693bd524478/requests-2.33.0-py3-none-any.whl", hash = "sha256:3324635456fa185245e24865e810cecec7b4caf933d7eb133dcde67d48cee69b", size = 65017, upload-time = "2026-03-25T15:10:40.382Z" }, ] [[package]] @@ -5981,27 +6008,27 @@ wheels = [ [[package]] name = "ruff" -version = "0.15.7" +version = "0.15.8" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/a1/22/9e4f66ee588588dc6c9af6a994e12d26e19efbe874d1a909d09a6dac7a59/ruff-0.15.7.tar.gz", hash = "sha256:04f1ae61fc20fe0b148617c324d9d009b5f63412c0b16474f3d5f1a1a665f7ac", size = 4601277, upload-time = "2026-03-19T16:26:22.605Z" } +sdist = { url = "https://files.pythonhosted.org/packages/14/b0/73cf7550861e2b4824950b8b52eebdcc5adc792a00c514406556c5b80817/ruff-0.15.8.tar.gz", hash = "sha256:995f11f63597ee362130d1d5a327a87cb6f3f5eae3094c620bcc632329a4d26e", size = 4610921, upload-time = "2026-03-26T18:39:38.675Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/41/2f/0b08ced94412af091807b6119ca03755d651d3d93a242682bf020189db94/ruff-0.15.7-py3-none-linux_armv6l.whl", hash = "sha256:a81cc5b6910fb7dfc7c32d20652e50fa05963f6e13ead3c5915c41ac5d16668e", size = 10489037, upload-time = "2026-03-19T16:26:32.47Z" }, - { url = "https://files.pythonhosted.org/packages/91/4a/82e0fa632e5c8b1eba5ee86ecd929e8ff327bbdbfb3c6ac5d81631bef605/ruff-0.15.7-py3-none-macosx_10_12_x86_64.whl", hash = 
"sha256:722d165bd52403f3bdabc0ce9e41fc47070ac56d7a91b4e0d097b516a53a3477", size = 10955433, upload-time = "2026-03-19T16:27:00.205Z" }, - { url = "https://files.pythonhosted.org/packages/ab/10/12586735d0ff42526ad78c049bf51d7428618c8b5c467e72508c694119df/ruff-0.15.7-py3-none-macosx_11_0_arm64.whl", hash = "sha256:7fbc2448094262552146cbe1b9643a92f66559d3761f1ad0656d4991491af49e", size = 10269302, upload-time = "2026-03-19T16:26:26.183Z" }, - { url = "https://files.pythonhosted.org/packages/eb/5d/32b5c44ccf149a26623671df49cbfbd0a0ae511ff3df9d9d2426966a8d57/ruff-0.15.7-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6b39329b60eba44156d138275323cc726bbfbddcec3063da57caa8a8b1d50adf", size = 10607625, upload-time = "2026-03-19T16:27:03.263Z" }, - { url = "https://files.pythonhosted.org/packages/5d/f1/f0001cabe86173aaacb6eb9bb734aa0605f9a6aa6fa7d43cb49cbc4af9c9/ruff-0.15.7-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:87768c151808505f2bfc93ae44e5f9e7c8518943e5074f76ac21558ef5627c85", size = 10324743, upload-time = "2026-03-19T16:27:09.791Z" }, - { url = "https://files.pythonhosted.org/packages/7a/87/b8a8f3d56b8d848008559e7c9d8bf367934d5367f6d932ba779456e2f73b/ruff-0.15.7-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fb0511670002c6c529ec66c0e30641c976c8963de26a113f3a30456b702468b0", size = 11138536, upload-time = "2026-03-19T16:27:06.101Z" }, - { url = "https://files.pythonhosted.org/packages/e4/f2/4fd0d05aab0c5934b2e1464784f85ba2eab9d54bffc53fb5430d1ed8b829/ruff-0.15.7-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e0d19644f801849229db8345180a71bee5407b429dd217f853ec515e968a6912", size = 11994292, upload-time = "2026-03-19T16:26:48.718Z" }, - { url = "https://files.pythonhosted.org/packages/64/22/fc4483871e767e5e95d1622ad83dad5ebb830f762ed0420fde7dfa9d9b08/ruff-0.15.7-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = 
"sha256:4806d8e09ef5e84eb19ba833d0442f7e300b23fe3f0981cae159a248a10f0036", size = 11398981, upload-time = "2026-03-19T16:26:54.513Z" }, - { url = "https://files.pythonhosted.org/packages/b0/99/66f0343176d5eab02c3f7fcd2de7a8e0dd7a41f0d982bee56cd1c24db62b/ruff-0.15.7-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dce0896488562f09a27b9c91b1f58a097457143931f3c4d519690dea54e624c5", size = 11242422, upload-time = "2026-03-19T16:26:29.277Z" }, - { url = "https://files.pythonhosted.org/packages/5d/3a/a7060f145bfdcce4c987ea27788b30c60e2c81d6e9a65157ca8afe646328/ruff-0.15.7-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:1852ce241d2bc89e5dc823e03cff4ce73d816b5c6cdadd27dbfe7b03217d2a12", size = 11232158, upload-time = "2026-03-19T16:26:42.321Z" }, - { url = "https://files.pythonhosted.org/packages/a7/53/90fbb9e08b29c048c403558d3cdd0adf2668b02ce9d50602452e187cd4af/ruff-0.15.7-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:5f3e4b221fb4bd293f79912fc5e93a9063ebd6d0dcbd528f91b89172a9b8436c", size = 10577861, upload-time = "2026-03-19T16:26:57.459Z" }, - { url = "https://files.pythonhosted.org/packages/2f/aa/5f486226538fe4d0f0439e2da1716e1acf895e2a232b26f2459c55f8ddad/ruff-0.15.7-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:b15e48602c9c1d9bdc504b472e90b90c97dc7d46c7028011ae67f3861ceba7b4", size = 10327310, upload-time = "2026-03-19T16:26:35.909Z" }, - { url = "https://files.pythonhosted.org/packages/99/9e/271afdffb81fe7bfc8c43ba079e9d96238f674380099457a74ccb3863857/ruff-0.15.7-py3-none-musllinux_1_2_i686.whl", hash = "sha256:1b4705e0e85cedc74b0a23cf6a179dbb3df184cb227761979cc76c0440b5ab0d", size = 10840752, upload-time = "2026-03-19T16:26:45.723Z" }, - { url = "https://files.pythonhosted.org/packages/bf/29/a4ae78394f76c7759953c47884eb44de271b03a66634148d9f7d11e721bd/ruff-0.15.7-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:112c1fa316a558bb34319282c1200a8bf0495f1b735aeb78bfcb2991e6087580", size = 11336961, upload-time = 
"2026-03-19T16:26:39.076Z" }, - { url = "https://files.pythonhosted.org/packages/26/6b/8786ba5736562220d588a2f6653e6c17e90c59ced34a2d7b512ef8956103/ruff-0.15.7-py3-none-win32.whl", hash = "sha256:6d39e2d3505b082323352f733599f28169d12e891f7dd407f2d4f54b4c2886de", size = 10582538, upload-time = "2026-03-19T16:26:15.992Z" }, - { url = "https://files.pythonhosted.org/packages/2b/e9/346d4d3fffc6871125e877dae8d9a1966b254fbd92a50f8561078b88b099/ruff-0.15.7-py3-none-win_amd64.whl", hash = "sha256:4d53d712ddebcd7dace1bc395367aec12c057aacfe9adbb6d832302575f4d3a1", size = 11755839, upload-time = "2026-03-19T16:26:19.897Z" }, - { url = "https://files.pythonhosted.org/packages/8f/e8/726643a3ea68c727da31570bde48c7a10f1aa60eddd628d94078fec586ff/ruff-0.15.7-py3-none-win_arm64.whl", hash = "sha256:18e8d73f1c3fdf27931497972250340f92e8c861722161a9caeb89a58ead6ed2", size = 11023304, upload-time = "2026-03-19T16:26:51.669Z" }, + { url = "https://files.pythonhosted.org/packages/4a/92/c445b0cd6da6e7ae51e954939cb69f97e008dbe750cfca89b8cedc081be7/ruff-0.15.8-py3-none-linux_armv6l.whl", hash = "sha256:cbe05adeba76d58162762d6b239c9056f1a15a55bd4b346cfd21e26cd6ad7bc7", size = 10527394, upload-time = "2026-03-26T18:39:41.566Z" }, + { url = "https://files.pythonhosted.org/packages/eb/92/f1c662784d149ad1414cae450b082cf736430c12ca78367f20f5ed569d65/ruff-0.15.8-py3-none-macosx_10_12_x86_64.whl", hash = "sha256:d3e3d0b6ba8dca1b7ef9ab80a28e840a20070c4b62e56d675c24f366ef330570", size = 10905693, upload-time = "2026-03-26T18:39:30.364Z" }, + { url = "https://files.pythonhosted.org/packages/ca/f2/7a631a8af6d88bcef997eb1bf87cc3da158294c57044aafd3e17030613de/ruff-0.15.8-py3-none-macosx_11_0_arm64.whl", hash = "sha256:6ee3ae5c65a42f273f126686353f2e08ff29927b7b7e203b711514370d500de3", size = 10323044, upload-time = "2026-03-26T18:39:33.37Z" }, + { url = 
"https://files.pythonhosted.org/packages/67/18/1bf38e20914a05e72ef3b9569b1d5c70a7ef26cd188d69e9ca8ef588d5bf/ruff-0.15.8-py3-none-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fdce027ada77baa448077ccc6ebb2fa9c3c62fd110d8659d601cf2f475858d94", size = 10629135, upload-time = "2026-03-26T18:39:44.142Z" }, + { url = "https://files.pythonhosted.org/packages/d2/e9/138c150ff9af60556121623d41aba18b7b57d95ac032e177b6a53789d279/ruff-0.15.8-py3-none-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:12e617fc01a95e5821648a6df341d80456bd627bfab8a829f7cfc26a14a4b4a3", size = 10348041, upload-time = "2026-03-26T18:39:52.178Z" }, + { url = "https://files.pythonhosted.org/packages/02/f1/5bfb9298d9c323f842c5ddeb85f1f10ef51516ac7a34ba446c9347d898df/ruff-0.15.8-py3-none-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:432701303b26416d22ba696c39f2c6f12499b89093b61360abc34bcc9bf07762", size = 11121987, upload-time = "2026-03-26T18:39:55.195Z" }, + { url = "https://files.pythonhosted.org/packages/10/11/6da2e538704e753c04e8d86b1fc55712fdbdcc266af1a1ece7a51fff0d10/ruff-0.15.8-py3-none-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:d910ae974b7a06a33a057cb87d2a10792a3b2b3b35e33d2699fdf63ec8f6b17a", size = 11951057, upload-time = "2026-03-26T18:39:19.18Z" }, + { url = "https://files.pythonhosted.org/packages/83/f0/c9208c5fd5101bf87002fed774ff25a96eea313d305f1e5d5744698dc314/ruff-0.15.8-py3-none-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2033f963c43949d51e6fdccd3946633c6b37c484f5f98c3035f49c27395a8ab8", size = 11464613, upload-time = "2026-03-26T18:40:06.301Z" }, + { url = "https://files.pythonhosted.org/packages/f8/22/d7f2fabdba4fae9f3b570e5605d5eb4500dcb7b770d3217dca4428484b17/ruff-0.15.8-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0f29b989a55572fb885b77464cf24af05500806ab4edf9a0fd8977f9759d85b1", size = 11257557, upload-time = "2026-03-26T18:39:57.972Z" }, + { url = 
"https://files.pythonhosted.org/packages/71/8c/382a9620038cf6906446b23ce8632ab8c0811b8f9d3e764f58bedd0c9a6f/ruff-0.15.8-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:ac51d486bf457cdc985a412fb1801b2dfd1bd8838372fc55de64b1510eff4bec", size = 11169440, upload-time = "2026-03-26T18:39:22.205Z" }, + { url = "https://files.pythonhosted.org/packages/4d/0d/0994c802a7eaaf99380085e4e40c845f8e32a562e20a38ec06174b52ef24/ruff-0.15.8-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:c9861eb959edab053c10ad62c278835ee69ca527b6dcd72b47d5c1e5648964f6", size = 10605963, upload-time = "2026-03-26T18:39:46.682Z" }, + { url = "https://files.pythonhosted.org/packages/19/aa/d624b86f5b0aad7cef6bbf9cd47a6a02dfdc4f72c92a337d724e39c9d14b/ruff-0.15.8-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:8d9a5b8ea13f26ae90838afc33f91b547e61b794865374f114f349e9036835fb", size = 10357484, upload-time = "2026-03-26T18:39:49.176Z" }, + { url = "https://files.pythonhosted.org/packages/35/c3/e0b7835d23001f7d999f3895c6b569927c4d39912286897f625736e1fd04/ruff-0.15.8-py3-none-musllinux_1_2_i686.whl", hash = "sha256:c2a33a529fb3cbc23a7124b5c6ff121e4d6228029cba374777bd7649cc8598b8", size = 10830426, upload-time = "2026-03-26T18:40:03.702Z" }, + { url = "https://files.pythonhosted.org/packages/f0/51/ab20b322f637b369383adc341d761eaaa0f0203d6b9a7421cd6e783d81b9/ruff-0.15.8-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:75e5cd06b1cf3f47a3996cfc999226b19aa92e7cce682dcd62f80d7035f98f49", size = 11345125, upload-time = "2026-03-26T18:39:27.799Z" }, + { url = "https://files.pythonhosted.org/packages/37/e6/90b2b33419f59d0f2c4c8a48a4b74b460709a557e8e0064cf33ad894f983/ruff-0.15.8-py3-none-win32.whl", hash = "sha256:bc1f0a51254ba21767bfa9a8b5013ca8149dcf38092e6a9eb704d876de94dc34", size = 10571959, upload-time = "2026-03-26T18:39:36.117Z" }, + { url = "https://files.pythonhosted.org/packages/1f/a2/ef467cb77099062317154c63f234b8a7baf7cb690b99af760c5b68b9ee7f/ruff-0.15.8-py3-none-win_amd64.whl", hash = 
"sha256:04f79eff02a72db209d47d665ba7ebcad609d8918a134f86cb13dd132159fc89", size = 11743893, upload-time = "2026-03-26T18:39:25.01Z" }, + { url = "https://files.pythonhosted.org/packages/15/e2/77be4fff062fa78d9b2a4dea85d14785dac5f1d0c1fb58ed52331f0ebe28/ruff-0.15.8-py3-none-win_arm64.whl", hash = "sha256:cf891fa8e3bb430c0e7fac93851a5978fc99c8fa2c053b57b118972866f8e5f2", size = 11048175, upload-time = "2026-03-26T18:40:01.06Z" }, ] [[package]] @@ -6402,7 +6429,7 @@ wheels = [ [[package]] name = "tablestore" -version = "6.4.1" +version = "6.4.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "aiohttp" }, @@ -6415,9 +6442,9 @@ dependencies = [ { name = "six" }, { name = "urllib3" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/62/00/53f8eeb0016e7ad518f92b085de8855891d10581b42f86d15d1df7a56d33/tablestore-6.4.1.tar.gz", hash = "sha256:005c6939832f2ecd403e01220b7045de45f2e53f1ffaf0c2efc435810885fffb", size = 120319, upload-time = "2026-02-13T06:58:37.267Z" } +sdist = { url = "https://files.pythonhosted.org/packages/09/07/afa1d18521bab13bb813066892b73589937fcf68aea63a54b0b14dae17b5/tablestore-6.4.2.tar.gz", hash = "sha256:5251e14b7c7ebf3d49d37dde957b49c7dba04ee8715c2650109cc02f3b89cc77", size = 5071435, upload-time = "2026-03-26T15:39:06.498Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/cc/96/a132bdecb753dc9dc34124a53019da29672baaa34485c8c504895897ea96/tablestore-6.4.1-py3-none-any.whl", hash = "sha256:616898d294dfe22f0d427463c241c6788374cdb2ace9aaf85673ce2c2a18d7e0", size = 141556, upload-time = "2026-02-13T06:58:35.579Z" }, + { url = "https://files.pythonhosted.org/packages/c7/3f/5fb3e8e5de36934fe38986b4e861657cebb3a6dfd97d32224cd40fc66359/tablestore-6.4.2-py3-none-any.whl", hash = "sha256:98c4cffa5eace4a3ea6fc2425263e733093c2baa43537f25dbaaf02e2b7882d8", size = 5114987, upload-time = "2026-03-26T15:39:04.074Z" }, ] [[package]] @@ -6443,7 +6470,7 @@ sdist = { url = 
"https://files.pythonhosted.org/packages/20/81/be13f417065200182 [[package]] name = "tcvectordb" -version = "2.0.0" +version = "2.1.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "cachetools" }, @@ -6456,9 +6483,9 @@ dependencies = [ { name = "ujson" }, { name = "urllib3" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/16/21/3bcd466df20ac69408c0228b1c5e793cf3283085238d3ef5d352c556b6ad/tcvectordb-2.0.0.tar.gz", hash = "sha256:38c6ed17931b9bd702138941ca6cfe10b2b60301424ffa36b64a3c2686318941", size = 82209, upload-time = "2025-12-27T07:55:27.376Z" } +sdist = { url = "https://files.pythonhosted.org/packages/34/4c/3510489c20823c045a4f84c3f656b1af00b3fbbfa36efc494cf01492521f/tcvectordb-2.1.0.tar.gz", hash = "sha256:382615573f2b6d3e21535b686feac8895169b8eb56078fc73abb020676a1622f", size = 85691, upload-time = "2026-03-25T12:55:27.509Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/af/10/e807b273348edef3b321194bc13b67d2cd4df64e22f0404b9e39082415c7/tcvectordb-2.0.0-py3-none-any.whl", hash = "sha256:1731d9c6c0d17a4199872747ddfb1dd3feb26f14ffe7a657f8a5ac3af4ddcdd1", size = 96256, upload-time = "2025-12-27T07:55:24.362Z" }, + { url = "https://files.pythonhosted.org/packages/99/cf/7f340b4dc30ed0d2758915d1c2a4b2e9f0c90ce4f322b7cf17e571c80a45/tcvectordb-2.1.0-py3-none-any.whl", hash = "sha256:afbfc5f82bda70480921b2308148cbd0c51c8b45b3eef6cea64ddd003c7577e9", size = 99615, upload-time = "2026-03-25T12:55:26.004Z" }, ] [[package]] @@ -6850,18 +6877,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/61/91/915c4a6e6e9bd2bca3ec0c21c1771b175c59e204b85e57f3f572370fe753/types_jmespath-1.1.0.20260124-py3-none-any.whl", hash = "sha256:ec387666d446b15624215aa9cbd2867ffd885b6c74246d357c65e830c7a138b3", size = 11509, upload-time = "2026-01-24T03:18:45.536Z" }, ] -[[package]] -name = "types-jsonschema" -version = "4.26.0.20260202" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = 
"referencing" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/a1/07/68f63e715eb327ed2f5292e29e8be99785db0f72c7664d2c63bd4dbdc29d/types_jsonschema-4.26.0.20260202.tar.gz", hash = "sha256:29831baa4308865a9aec547a61797a06fc152b0dac8dddd531e002f32265cb07", size = 16168, upload-time = "2026-02-02T04:11:22.585Z" } -wheels = [ - { url = "https://files.pythonhosted.org/packages/c1/06/962d4f364f779d7389cd31a1bb581907b057f52f0ace2c119a8dd8409db6/types_jsonschema-4.26.0.20260202-py3-none-any.whl", hash = "sha256:41c95343abc4de9264e333a55e95dfb4d401e463856d0164eec9cb182e8746da", size = 15914, upload-time = "2026-02-02T04:11:21.61Z" }, -] - [[package]] name = "types-markdown" version = "3.10.2.20260211" @@ -6873,11 +6888,11 @@ wheels = [ [[package]] name = "types-oauthlib" -version = "3.3.0.20250822" +version = "3.3.0.20260324" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/6a/6e/d08033f562053c459322333c46baa8cf8d2d8c18f30d46dd898c8fd8df77/types_oauthlib-3.3.0.20250822.tar.gz", hash = "sha256:2cd41587dd80c199e4230e3f086777e9ae525e89579c64afe5e0039ab09be9de", size = 25700, upload-time = "2025-08-22T03:02:41.378Z" } +sdist = { url = "https://files.pythonhosted.org/packages/91/38/543938f86d81bd6a78b8c355fe81bb8da0a26e4c28addfe3443e38a683d2/types_oauthlib-3.3.0.20260324.tar.gz", hash = "sha256:3c4cc07fa33886f881682237c1e445c5f1778b44efea118f4c1e4ede82cb52f2", size = 26030, upload-time = "2026-03-24T04:06:30.898Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/18/4b/00593b8b5d055550e1fcb9af2c42fa11b0a90bf16a94759a77bc1c3c0c72/types_oauthlib-3.3.0.20250822-py3-none-any.whl", hash = "sha256:b7f4c9b9eed0e020f454e0af800b10e93dd2efd196da65744b76910cce7e70d6", size = 48800, upload-time = "2025-08-22T03:02:40.427Z" }, + { url = "https://files.pythonhosted.org/packages/0e/60/26f0ddade4b2bb17b3d8f3ebaac436e5487caec28831da3d7ea309fe93b9/types_oauthlib-3.3.0.20260324-py3-none-any.whl", hash = 
"sha256:d24662033b04f4d50a2f1fed04c1b43ff2554aa037c1dafa0424f87100a46ccd", size = 48984, upload-time = "2026-03-24T04:06:29.696Z" }, ] [[package]] @@ -7028,11 +7043,11 @@ wheels = [ [[package]] name = "types-regex" -version = "2026.2.28.20260301" +version = "2026.3.32.20260329" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/3a/ed/106958cb686316113b748ed4209fa363fd92b15759d5409c3930fed36606/types_regex-2026.2.28.20260301.tar.gz", hash = "sha256:644c231db3f368908320170c14905731a7ae5fabdac0f60f5d6d12ecdd3bc8dd", size = 13157, upload-time = "2026-03-01T04:11:13.559Z" } +sdist = { url = "https://files.pythonhosted.org/packages/f2/d8/a3aca5775c573e56d201bbd76a827b84d851a4bce28e189e5acb9c7a0d15/types_regex-2026.3.32.20260329.tar.gz", hash = "sha256:12653e44694cb3e3ccdc39bab3d433d2a83fec1c01220e6871fd6f3cf434675c", size = 13111, upload-time = "2026-03-29T04:27:04.759Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/c7/bb/9bc26fcf5155bd25efeca35f8ba6bffb8b3c9da2baac8bf40067606418f3/types_regex-2026.2.28.20260301-py3-none-any.whl", hash = "sha256:7da7a1fe67528238176a5844fd435ca90617cf605341308686afbc579fdea5c0", size = 11130, upload-time = "2026-03-01T04:11:11.454Z" }, + { url = "https://files.pythonhosted.org/packages/89/f4/a1db307e56753c49fb15fc88d70fadeb3f38897b28cab645cddd18054c79/types_regex-2026.3.32.20260329-py3-none-any.whl", hash = "sha256:861d0893bcfe08a57eb7486a502014e29dc2721d46dd5130798fbccafdb31cc0", size = 11128, upload-time = "2026-03-29T04:27:03.854Z" }, ] [[package]] @@ -7692,7 +7707,7 @@ wheels = [ [[package]] name = "xinference-client" -version = "2.3.1" +version = "2.4.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "aiohttp" }, @@ -7700,9 +7715,9 @@ dependencies = [ { name = "requests" }, { name = "typing-extensions" }, ] -sdist = { url = 
"https://files.pythonhosted.org/packages/bc/7a/33aeef9cffdc331de0046c25412622c5a16226d1b4e0cca9ed512ad00b9a/xinference_client-2.3.1.tar.gz", hash = "sha256:23ae225f47ff9adf4c6f7718c54993d1be8c704d727509f6e5cb670de3e02c4d", size = 58414, upload-time = "2026-03-15T05:53:23.994Z" } +sdist = { url = "https://files.pythonhosted.org/packages/0e/f2/7640528fd4f816df19afe91d52332a658ad2d2bacb13471b0a27dbd0cf46/xinference_client-2.4.0.tar.gz", hash = "sha256:59de6d58f89126c8ff05136818e0756108e534858255d7c4c0673b804fd2d01d", size = 58386, upload-time = "2026-03-29T05:10:58.533Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/74/8d/d9ab0a457718050a279b9bb6515b7245d114118dc5e275f190ef2628dd16/xinference_client-2.3.1-py3-none-any.whl", hash = "sha256:f7c4f0b56635b46be9cfd9b2affa8e15275491597ac9b958e14b13da5745133e", size = 40012, upload-time = "2026-03-15T05:53:22.797Z" }, + { url = "https://files.pythonhosted.org/packages/73/cf/9d27e0095cc28691c73ff186b33556790c7b87f046ca2ecd517c80272592/xinference_client-2.4.0-py3-none-any.whl", hash = "sha256:2f9478b00fe15643f281fe4c0643e74479c8b7837d377000ff120702cda81efc", size = 40012, upload-time = "2026-03-29T05:10:57.279Z" }, ] [[package]] diff --git a/dev/pytest/pytest_unit_tests.sh b/dev/pytest/pytest_unit_tests.sh index a034083304..1d4ff4d86f 100755 --- a/dev/pytest/pytest_unit_tests.sh +++ b/dev/pytest/pytest_unit_tests.sh @@ -13,4 +13,4 @@ PYTEST_XDIST_ARGS="${PYTEST_XDIST_ARGS:--n auto}" pytest --timeout "${PYTEST_TIMEOUT}" ${PYTEST_XDIST_ARGS} api/tests/unit_tests --ignore=api/tests/unit_tests/controllers # Run controller tests sequentially to avoid import race conditions -pytest --timeout "${PYTEST_TIMEOUT}" api/tests/unit_tests/controllers +pytest --timeout "${PYTEST_TIMEOUT}" --cov-append api/tests/unit_tests/controllers diff --git a/docker/.env.example b/docker/.env.example index 8cf77cf56b..9fbf9a9e72 100644 --- a/docker/.env.example +++ b/docker/.env.example @@ -488,7 +488,8 @@ 
ALIYUN_OSS_REGION=ap-southeast-1 ALIYUN_OSS_AUTH_VERSION=v4 # Don't start with '/'. OSS doesn't support leading slash in object names. ALIYUN_OSS_PATH=your-path -ALIYUN_CLOUDBOX_ID=your-cloudbox-id +# Optional CloudBox ID for Aliyun OSS, DO NOT enable it if you are not using CloudBox. +#ALIYUN_CLOUDBOX_ID=your-cloudbox-id # Tencent COS Configuration # diff --git a/docker/docker-compose-template.yaml b/docker/docker-compose-template.yaml index 04bd2858ff..e55cf942c3 100644 --- a/docker/docker-compose-template.yaml +++ b/docker/docker-compose-template.yaml @@ -21,7 +21,7 @@ services: # API service api: - image: langgenius/dify-api:1.13.2 + image: langgenius/dify-api:1.13.3 restart: always environment: # Use the shared environment variables. @@ -63,7 +63,7 @@ services: # worker service # The Celery worker for processing all queues (dataset, workflow, mail, etc.) worker: - image: langgenius/dify-api:1.13.2 + image: langgenius/dify-api:1.13.3 restart: always environment: # Use the shared environment variables. @@ -102,7 +102,7 @@ services: # worker_beat service # Celery beat for scheduling periodic tasks. worker_beat: - image: langgenius/dify-api:1.13.2 + image: langgenius/dify-api:1.13.3 restart: always environment: # Use the shared environment variables. @@ -132,7 +132,7 @@ services: # Frontend web application. web: - image: langgenius/dify-web:1.13.2 + image: langgenius/dify-web:1.13.3 restart: always environment: CONSOLE_API_URL: ${CONSOLE_API_URL:-} @@ -245,7 +245,7 @@ services: # The DifySandbox sandbox: - image: langgenius/dify-sandbox:0.2.12 + image: langgenius/dify-sandbox:0.2.14 restart: always environment: # The DifySandbox configurations @@ -269,12 +269,13 @@ services: # plugin daemon plugin_daemon: - image: langgenius/dify-plugin-daemon:0.5.4-local + image: langgenius/dify-plugin-daemon:0.5.3-local restart: always environment: # Use the shared environment variables. 
<<: *shared-api-worker-env DB_DATABASE: ${DB_PLUGIN_DATABASE:-dify_plugin} + DB_SSL_MODE: ${DB_SSL_MODE:-disable} SERVER_PORT: ${PLUGIN_DAEMON_PORT:-5002} SERVER_KEY: ${PLUGIN_DAEMON_KEY:-lYkiYYT6owG+71oLerGzA7GXCgOT++6ovaezWAjpCjf+Sjc3ZtU+qUEi} MAX_PLUGIN_PACKAGE_SIZE: ${PLUGIN_MAX_PACKAGE_SIZE:-52428800} diff --git a/docker/docker-compose.middleware.yaml b/docker/docker-compose.middleware.yaml index 2dca581903..911da70a73 100644 --- a/docker/docker-compose.middleware.yaml +++ b/docker/docker-compose.middleware.yaml @@ -97,7 +97,7 @@ services: # The DifySandbox sandbox: - image: langgenius/dify-sandbox:0.2.12 + image: langgenius/dify-sandbox:0.2.14 restart: always env_file: - ./middleware.env @@ -123,10 +123,12 @@ services: # plugin daemon plugin_daemon: - image: langgenius/dify-plugin-daemon:0.5.4-local + image: langgenius/dify-plugin-daemon:0.5.3-local restart: always env_file: - ./middleware.env + extra_hosts: + - "host.docker.internal:host-gateway" environment: # Use the shared environment variables. LOG_OUTPUT_FORMAT: ${LOG_OUTPUT_FORMAT:-text} diff --git a/docker/docker-compose.yaml b/docker/docker-compose.yaml index 6e11cac678..737a62020c 100644 --- a/docker/docker-compose.yaml +++ b/docker/docker-compose.yaml @@ -146,7 +146,6 @@ x-shared-env: &shared-api-worker-env ALIYUN_OSS_REGION: ${ALIYUN_OSS_REGION:-ap-southeast-1} ALIYUN_OSS_AUTH_VERSION: ${ALIYUN_OSS_AUTH_VERSION:-v4} ALIYUN_OSS_PATH: ${ALIYUN_OSS_PATH:-your-path} - ALIYUN_CLOUDBOX_ID: ${ALIYUN_CLOUDBOX_ID:-your-cloudbox-id} TENCENT_COS_BUCKET_NAME: ${TENCENT_COS_BUCKET_NAME:-your-bucket-name} TENCENT_COS_SECRET_KEY: ${TENCENT_COS_SECRET_KEY:-your-secret-key} TENCENT_COS_SECRET_ID: ${TENCENT_COS_SECRET_ID:-your-secret-id} @@ -731,7 +730,7 @@ services: # API service api: - image: langgenius/dify-api:1.13.2 + image: langgenius/dify-api:1.13.3 restart: always environment: # Use the shared environment variables. 
@@ -773,7 +772,7 @@ services: # worker service # The Celery worker for processing all queues (dataset, workflow, mail, etc.) worker: - image: langgenius/dify-api:1.13.2 + image: langgenius/dify-api:1.13.3 restart: always environment: # Use the shared environment variables. @@ -812,7 +811,7 @@ services: # worker_beat service # Celery beat for scheduling periodic tasks. worker_beat: - image: langgenius/dify-api:1.13.2 + image: langgenius/dify-api:1.13.3 restart: always environment: # Use the shared environment variables. @@ -842,7 +841,7 @@ services: # Frontend web application. web: - image: langgenius/dify-web:1.13.2 + image: langgenius/dify-web:1.13.3 restart: always environment: CONSOLE_API_URL: ${CONSOLE_API_URL:-} @@ -955,7 +954,7 @@ services: # The DifySandbox sandbox: - image: langgenius/dify-sandbox:0.2.12 + image: langgenius/dify-sandbox:0.2.14 restart: always environment: # The DifySandbox configurations @@ -979,12 +978,13 @@ services: # plugin daemon plugin_daemon: - image: langgenius/dify-plugin-daemon:0.5.4-local + image: langgenius/dify-plugin-daemon:0.5.3-local restart: always environment: # Use the shared environment variables. 
<<: *shared-api-worker-env DB_DATABASE: ${DB_PLUGIN_DATABASE:-dify_plugin} + DB_SSL_MODE: ${DB_SSL_MODE:-disable} SERVER_PORT: ${PLUGIN_DAEMON_PORT:-5002} SERVER_KEY: ${PLUGIN_DAEMON_KEY:-lYkiYYT6owG+71oLerGzA7GXCgOT++6ovaezWAjpCjf+Sjc3ZtU+qUEi} MAX_PLUGIN_PACKAGE_SIZE: ${PLUGIN_MAX_PACKAGE_SIZE:-52428800} diff --git a/docker/volumes/sandbox/conf/config.yaml b/docker/volumes/sandbox/conf/config.yaml index 8c1a1deb54..3b4a6b8439 100644 --- a/docker/volumes/sandbox/conf/config.yaml +++ b/docker/volumes/sandbox/conf/config.yaml @@ -5,7 +5,8 @@ app: max_workers: 4 max_requests: 50 worker_timeout: 5 -python_path: /usr/local/bin/python3 +python_path: /opt/python/bin/python3 +nodejs_path: /usr/local/bin/node enable_network: True # please make sure there is no network risk in your environment allowed_syscalls: # please leave it empty if you have no idea how seccomp works proxy: diff --git a/docker/volumes/sandbox/conf/config.yaml.example b/docker/volumes/sandbox/conf/config.yaml.example index f92c19e51a..365089cb9e 100644 --- a/docker/volumes/sandbox/conf/config.yaml.example +++ b/docker/volumes/sandbox/conf/config.yaml.example @@ -5,7 +5,7 @@ app: max_workers: 4 max_requests: 50 worker_timeout: 5 -python_path: /usr/local/bin/python3 +python_path: /opt/python/bin/python3 python_lib_path: - /usr/local/lib/python3.10 - /usr/lib/python3.10 diff --git a/docs/ar-SA/README.md b/docs/ar-SA/README.md index 99e3e3567e..af5a9bfdc6 100644 --- a/docs/ar-SA/README.md +++ b/docs/ar-SA/README.md @@ -53,7 +53,11 @@ Türkçe README README Tiếng Việt README in Deutsch + README in Italiano + README em Português do Brasil + README Slovenščina README in বাংলা + README in हिन्दी

diff --git a/docs/bn-BD/README.md b/docs/bn-BD/README.md index f3fa68b466..5dceacb187 100644 --- a/docs/bn-BD/README.md +++ b/docs/bn-BD/README.md @@ -57,7 +57,11 @@ Türkçe README README Tiếng Việt README in Deutsch + README in Italiano + README em Português do Brasil + README Slovenščina README in বাংলা + README in हिन्दी

ডিফাই একটি ওপেন-সোর্স LLM অ্যাপ ডেভেলপমেন্ট প্ল্যাটফর্ম। এটি ইন্টুইটিভ ইন্টারফেস, এজেন্টিক AI ওয়ার্কফ্লো, RAG পাইপলাইন, এজেন্ট ক্যাপাবিলিটি, মডেল ম্যানেজমেন্ট, মনিটরিং সুবিধা এবং আরও অনেক কিছু একত্রিত করে, যা দ্রুত প্রোটোটাইপ থেকে প্রোডাকশন পর্যন্ত নিয়ে যেতে সহায়তা করে। diff --git a/docs/de-DE/README.md b/docs/de-DE/README.md index c71a0bfccf..1eab517a6d 100644 --- a/docs/de-DE/README.md +++ b/docs/de-DE/README.md @@ -57,7 +57,11 @@ Türkçe README README Tiếng Việt README in Deutsch + README in Italiano + README em Português do Brasil + README Slovenščina README in বাংলা + README in हिन्दी

Dify ist eine Open-Source-Plattform zur Entwicklung von LLM-Anwendungen. Ihre intuitive Benutzeroberfläche vereint agentenbasierte KI-Workflows, RAG-Pipelines, Agentenfunktionen, Modellverwaltung, Überwachungsfunktionen und mehr, sodass Sie schnell von einem Prototyp in die Produktion übergehen können. diff --git a/docs/es-ES/README.md b/docs/es-ES/README.md index da81b51d6a..f4c60e3d8f 100644 --- a/docs/es-ES/README.md +++ b/docs/es-ES/README.md @@ -53,7 +53,11 @@ Türkçe README README Tiếng Việt README in Deutsch + README in Italiano + README em Português do Brasil + README Slovenščina README in বাংলা + README in हिन्दी

# diff --git a/docs/fr-FR/README.md b/docs/fr-FR/README.md index 291c8dab40..db8730b36b 100644 --- a/docs/fr-FR/README.md +++ b/docs/fr-FR/README.md @@ -53,7 +53,11 @@ Türkçe README README Tiếng Việt README in Deutsch + README in Italiano + README em Português do Brasil + README Slovenščina README in বাংলা + README in हिन्दी

# diff --git a/docs/hi-IN/README.md b/docs/hi-IN/README.md index bedeaa6246..ad712046b5 100644 --- a/docs/hi-IN/README.md +++ b/docs/hi-IN/README.md @@ -58,6 +58,8 @@ README Tiếng Việt README in Deutsch README in Italiano + README em Português do Brasil + README Slovenščina README in বাংলা README in हिन्दी

diff --git a/docs/it-IT/README.md b/docs/it-IT/README.md index 2e96335d3e..bca560b574 100644 --- a/docs/it-IT/README.md +++ b/docs/it-IT/README.md @@ -58,7 +58,10 @@ README Tiếng Việt README in Deutsch README in Italiano + README em Português do Brasil + README Slovenščina README in বাংলা + README in हिन्दी

Dify è una piattaforma open-source per lo sviluppo di applicazioni LLM. La sua interfaccia intuitiva combina flussi di lavoro AI basati su agenti, pipeline RAG, funzionalità di agenti, gestione dei modelli, funzionalità di monitoraggio e altro ancora, permettendovi di passare rapidamente da un prototipo alla produzione. diff --git a/docs/ja-JP/README.md b/docs/ja-JP/README.md index 659ffbda51..298dcb95aa 100644 --- a/docs/ja-JP/README.md +++ b/docs/ja-JP/README.md @@ -53,7 +53,11 @@ Türkçe README README Tiếng Việt README in Deutsch + README in Italiano + README em Português do Brasil + README Slovenščina README in বাংলা + README in हिन्दी

# diff --git a/docs/ko-KR/README.md b/docs/ko-KR/README.md index 2f6c526ef2..2dcacaae8b 100644 --- a/docs/ko-KR/README.md +++ b/docs/ko-KR/README.md @@ -53,7 +53,11 @@ Türkçe README README Tiếng Việt README in Deutsch - README in বাংলা + README in Italiano + README em Português do Brasil + README Slovenščina + README in বাংলা + README in हिन्दी

Dify는 오픈 소스 LLM 앱 개발 플랫폼입니다. 직관적인 인터페이스를 통해 AI 워크플로우, RAG 파이프라인, 에이전트 기능, 모델 관리, 관찰 기능 등을 결합하여 프로토타입에서 프로덕션까지 빠르게 전환할 수 있습니다. 주요 기능 목록은 다음과 같습니다:

diff --git a/docs/pt-BR/README.md b/docs/pt-BR/README.md index ed29ec0294..f7cc37a20f 100644 --- a/docs/pt-BR/README.md +++ b/docs/pt-BR/README.md @@ -58,7 +58,10 @@ README em Vietnamita README em Português - BR README in Deutsch + README in Italiano + README Slovenščina README in বাংলা + README in हिन्दी

Dify é uma plataforma de desenvolvimento de aplicativos LLM de código aberto. Sua interface intuitiva combina workflow de IA, pipeline RAG, capacidades de agente, gerenciamento de modelos, recursos de observabilidade e muito mais, permitindo que você vá rapidamente do protótipo à produção. Aqui está uma lista das principais funcionalidades: diff --git a/docs/sl-SI/README.md b/docs/sl-SI/README.md index caef2c303c..7b3fe76b5d 100644 --- a/docs/sl-SI/README.md +++ b/docs/sl-SI/README.md @@ -53,9 +53,12 @@ README بالعربية Türkçe README README Tiếng Việt - README Slovenščina README in Deutsch + README in Italiano + README em Português do Brasil + README Slovenščina README in বাংলা + README in हिन्दी

Dify je odprtokodna platforma za razvoj aplikacij LLM. Njegov intuitivni vmesnik združuje agentski potek dela z umetno inteligenco, cevovod RAG, zmogljivosti agentov, upravljanje modelov, funkcije opazovanja in več, kar vam omogoča hiter prehod od prototipa do proizvodnje. diff --git a/docs/tlh/README.md b/docs/tlh/README.md index e2acd7734c..a8e63026c8 100644 --- a/docs/tlh/README.md +++ b/docs/tlh/README.md @@ -53,7 +53,11 @@ Türkçe README README Tiếng Việt README in Deutsch + README in Italiano + README em Português do Brasil + README Slovenščina README in বাংলা + README in हिन्दी

# diff --git a/docs/tr-TR/README.md b/docs/tr-TR/README.md index 6361ca5dd9..cecc1b189c 100644 --- a/docs/tr-TR/README.md +++ b/docs/tr-TR/README.md @@ -53,7 +53,11 @@ Türkçe README README Tiếng Việt README in Deutsch + README in Italiano + README em Português do Brasil + README Slovenščina README in বাংলা + README in हिन्दी

Dify, açık kaynaklı bir LLM uygulama geliştirme platformudur. Sezgisel arayüzü, AI iş akışı, RAG pipeline'ı, ajan yetenekleri, model yönetimi, gözlemlenebilirlik özellikleri ve daha fazlasını birleştirerek, prototipten üretime hızlıca geçmenizi sağlar. İşte temel özelliklerin bir listesi: diff --git a/docs/vi-VN/README.md b/docs/vi-VN/README.md index 3042a98d95..5230d53110 100644 --- a/docs/vi-VN/README.md +++ b/docs/vi-VN/README.md @@ -53,7 +53,11 @@ Türkçe README README Tiếng Việt README in Deutsch + README in Italiano + README em Português do Brasil + README Slovenščina README in বাংলা + README in हिन्दी

Dify là một nền tảng phát triển ứng dụng LLM mã nguồn mở. Giao diện trực quan kết hợp quy trình làm việc AI, mô hình RAG, khả năng tác nhân, quản lý mô hình, tính năng quan sát và hơn thế nữa, cho phép bạn nhanh chóng chuyển từ nguyên mẫu sang sản phẩm. Đây là danh sách các tính năng cốt lõi: diff --git a/docs/zh-CN/README.md b/docs/zh-CN/README.md index 15bb447ad8..8ba8009959 100644 --- a/docs/zh-CN/README.md +++ b/docs/zh-CN/README.md @@ -53,7 +53,11 @@ Türkçe README README Tiếng Việt README in Deutsch + README in Italiano + README em Português do Brasil + README Slovenščina README in বাংলা + README in हिन्दी
# diff --git a/docs/zh-TW/README.md b/docs/zh-TW/README.md index 14b343ba29..de5bab8679 100644 --- a/docs/zh-TW/README.md +++ b/docs/zh-TW/README.md @@ -57,6 +57,11 @@ Türkçe README README Tiếng Việt README in Deutsch + README in Italiano + README em Português do Brasil + README Slovenščina + README in বাংলা + README in हिन्दी

Dify 是一個開源的 LLM 應用程式開發平台。其直觀的界面結合了智能代理工作流程、RAG 管道、代理功能、模型管理、可觀察性功能等,讓您能夠快速從原型進展到生產環境。 diff --git a/e2e/.gitignore b/e2e/.gitignore new file mode 100644 index 0000000000..96c1e0f3a1 --- /dev/null +++ b/e2e/.gitignore @@ -0,0 +1,6 @@ +node_modules/ +.auth/ +playwright-report/ +test-results/ +cucumber-report/ +.logs/ diff --git a/e2e/AGENTS.md b/e2e/AGENTS.md new file mode 100644 index 0000000000..245c9863d4 --- /dev/null +++ b/e2e/AGENTS.md @@ -0,0 +1,164 @@ +# E2E + +This package contains the repository-level end-to-end tests for Dify. + +This file is the canonical package guide for `e2e/`. Keep detailed workflow, architecture, debugging, and reporting documentation here. Keep `README.md` as a minimal pointer to this file so the two documents do not drift. + +The suite uses Cucumber for scenario definitions and Playwright as the browser execution layer. + +It tests: + +- backend API started from source +- frontend served from the production artifact +- middleware services started from Docker + +## Prerequisites + +- Node.js `^22.22.1` +- `pnpm` +- `uv` +- Docker + +Install Playwright browsers once: + +```bash +cd e2e +pnpm install +pnpm e2e:install +pnpm check +``` + +Use `pnpm check` as the default local verification step after editing E2E TypeScript, Cucumber support code, or feature glue. It runs formatting, linting, and type checks for this package. 
+ +Common commands: + +```bash +# authenticated-only regression (default excludes @fresh) +# expects backend API, frontend artifact, and middleware stack to already be running +pnpm e2e + +# full reset + fresh install + authenticated scenarios +# starts required middleware/dependencies for you +pnpm e2e:full + +# run a tagged subset +pnpm e2e -- --tags @smoke + +# headed browser +pnpm e2e:headed -- --tags @smoke + +# slow down browser actions for local debugging +E2E_SLOW_MO=500 pnpm e2e:headed -- --tags @smoke +``` + +Frontend artifact behavior: + +- if `web/.next/BUILD_ID` exists, E2E reuses the existing build by default +- if you set `E2E_FORCE_WEB_BUILD=1`, E2E rebuilds the frontend before starting it + +## Lifecycle + +```mermaid +flowchart TD + A["Start E2E run"] --> B["run-cucumber.ts orchestrates setup/API/frontend"] + B --> C["support/web-server.ts starts or reuses frontend directly"] + C --> D["Cucumber loads config, steps, and support modules"] + D --> E["BeforeAll bootstraps shared auth state via /install"] + E --> F{"Which command is running?"} + F -->|`pnpm e2e`| G["Run config default tags: not @fresh and not @skip"] + F -->|`pnpm e2e:full*`| H["Override tags to not @skip"] + G --> I["Per-scenario BrowserContext from shared browser"] + H --> I + I --> J["Failure artifacts written to cucumber-report/artifacts"] +``` + +Ownership is split like this: + +- `scripts/setup.ts` is the single environment entrypoint for reset, middleware, backend, and frontend startup +- `run-cucumber.ts` orchestrates the E2E run and Cucumber invocation +- `support/web-server.ts` manages frontend reuse, startup, readiness, and shutdown +- `features/support/hooks.ts` manages auth bootstrap, scenario lifecycle, and diagnostics +- `features/support/world.ts` owns per-scenario typed context +- `features/step-definitions/` holds domain-oriented glue so the official VS Code Cucumber plugin works with default conventions when `e2e/` is opened as the workspace root + +Package layout: 
+ +- `features/`: Gherkin scenarios grouped by capability +- `features/step-definitions/`: domain-oriented step definitions +- `features/support/hooks.ts`: suite lifecycle, auth-state bootstrap, diagnostics +- `features/support/world.ts`: shared scenario context +- `support/web-server.ts`: typed frontend startup/reuse logic +- `scripts/setup.ts`: reset and service lifecycle commands +- `scripts/run-cucumber.ts`: Cucumber orchestration entrypoint + +Behavior depends on instance state: + +- uninitialized instance: completes install and stores authenticated state +- initialized instance: signs in and reuses authenticated state + +Because of that, the `@fresh` install scenario only runs in the `pnpm e2e:full*` flows. The default `pnpm e2e*` flows exclude `@fresh` via Cucumber config tags so they can be re-run against an already initialized instance. + +Reset all persisted E2E state: + +```bash +pnpm e2e:reset +``` + +This removes: + +- `docker/volumes/db/data` +- `docker/volumes/redis/data` +- `docker/volumes/weaviate` +- `docker/volumes/plugin_daemon` +- `e2e/.auth` +- `e2e/.logs` +- `e2e/cucumber-report` + +Start the full middleware stack: + +```bash +pnpm e2e:middleware:up +``` + +Stop the full middleware stack: + +```bash +pnpm e2e:middleware:down +``` + +The middleware stack includes: + +- PostgreSQL +- Redis +- Weaviate +- Sandbox +- SSRF proxy +- Plugin daemon + +Fresh install verification: + +```bash +pnpm e2e:full +``` + +Run the Cucumber suite against an already running middleware stack: + +```bash +pnpm e2e:middleware:up +pnpm e2e +pnpm e2e:middleware:down +``` + +Artifacts and diagnostics: + +- `cucumber-report/report.html`: HTML report +- `cucumber-report/report.json`: JSON report +- `cucumber-report/artifacts/`: failure screenshots and HTML captures +- `.logs/cucumber-api.log`: backend startup log +- `.logs/cucumber-web.log`: frontend startup log + +Open the HTML report locally with: + +```bash +open cucumber-report/report.html +``` diff --git 
a/e2e/README.md b/e2e/README.md new file mode 100644 index 0000000000..9b4046eaff --- /dev/null +++ b/e2e/README.md @@ -0,0 +1,3 @@ +# E2E + +Canonical documentation for this package lives in [AGENTS.md](./AGENTS.md). diff --git a/e2e/cucumber.config.ts b/e2e/cucumber.config.ts new file mode 100644 index 0000000000..c162a6562e --- /dev/null +++ b/e2e/cucumber.config.ts @@ -0,0 +1,19 @@ +import type { IConfiguration } from '@cucumber/cucumber' + +const config = { + format: [ + 'progress-bar', + 'summary', + 'html:./cucumber-report/report.html', + 'json:./cucumber-report/report.json', + ], + import: ['features/**/*.ts'], + parallel: 1, + paths: ['features/**/*.feature'], + tags: process.env.E2E_CUCUMBER_TAGS || 'not @fresh and not @skip', + timeout: 60_000, +} satisfies Partial<IConfiguration> & { + timeout: number +} + +export default config diff --git a/e2e/features/apps/create-app.feature b/e2e/features/apps/create-app.feature new file mode 100644 index 0000000000..c0ca8ea4e0 --- /dev/null +++ b/e2e/features/apps/create-app.feature @@ -0,0 +1,10 @@ +@apps @authenticated +Feature: Create app + Scenario: Create a new blank app and redirect to the editor + Given I am signed in as the default E2E admin + When I open the apps console + And I start creating a blank app + And I enter a unique E2E app name + And I confirm app creation + Then I should land on the app editor + And I should see the "Orchestrate" text diff --git a/e2e/features/smoke/authenticated-entry.feature b/e2e/features/smoke/authenticated-entry.feature new file mode 100644 index 0000000000..3c1191a330 --- /dev/null +++ b/e2e/features/smoke/authenticated-entry.feature @@ -0,0 +1,8 @@ +@smoke @authenticated +Feature: Authenticated app console + Scenario: Open the apps console with the shared authenticated state + Given I am signed in as the default E2E admin + When I open the apps console + Then I should stay on the apps console + And I should see the "Create from Blank" button + And I should not see the "Sign in" button
diff --git a/e2e/features/smoke/install.feature b/e2e/features/smoke/install.feature new file mode 100644 index 0000000000..39fc1f996b --- /dev/null +++ b/e2e/features/smoke/install.feature @@ -0,0 +1,7 @@ +@smoke @fresh +Feature: Fresh installation bootstrap + Scenario: Complete the initial installation bootstrap on a fresh instance + Given the last authentication bootstrap came from a fresh install + When I open the apps console + Then I should stay on the apps console + And I should see the "Create from Blank" button diff --git a/e2e/features/step-definitions/apps/create-app.steps.ts b/e2e/features/step-definitions/apps/create-app.steps.ts new file mode 100644 index 0000000000..b8e76c6f06 --- /dev/null +++ b/e2e/features/step-definitions/apps/create-app.steps.ts @@ -0,0 +1,29 @@ +import { Then, When } from '@cucumber/cucumber' +import { expect } from '@playwright/test' +import type { DifyWorld } from '../../support/world' + +When('I start creating a blank app', async function (this: DifyWorld) { + const page = this.getPage() + + await expect(page.getByRole('button', { name: 'Create from Blank' })).toBeVisible() + await page.getByRole('button', { name: 'Create from Blank' }).click() +}) + +When('I enter a unique E2E app name', async function (this: DifyWorld) { + const appName = `E2E App ${Date.now()}` + + await this.getPage().getByPlaceholder('Give your app a name').fill(appName) +}) + +When('I confirm app creation', async function (this: DifyWorld) { + const createButton = this.getPage() + .getByRole('button', { name: /^Create(?:\s|$)/ }) + .last() + + await expect(createButton).toBeEnabled() + await createButton.click() +}) + +Then('I should land on the app editor', async function (this: DifyWorld) { + await expect(this.getPage()).toHaveURL(/\/app\/[^/]+\/(workflow|configuration)(?:\?.*)?$/) +}) diff --git a/e2e/features/step-definitions/common/auth.steps.ts b/e2e/features/step-definitions/common/auth.steps.ts new file mode 100644 index 0000000000..bf03c2d8f4 
--- /dev/null +++ b/e2e/features/step-definitions/common/auth.steps.ts @@ -0,0 +1,11 @@ +import { Given } from '@cucumber/cucumber' +import type { DifyWorld } from '../../support/world' + +Given('I am signed in as the default E2E admin', async function (this: DifyWorld) { + const session = await this.getAuthSession() + + this.attach( + `Authenticated as ${session.adminEmail} using ${session.mode} flow at ${session.baseURL}.`, + 'text/plain', + ) +}) diff --git a/e2e/features/step-definitions/common/navigation.steps.ts b/e2e/features/step-definitions/common/navigation.steps.ts new file mode 100644 index 0000000000..b18ff035fa --- /dev/null +++ b/e2e/features/step-definitions/common/navigation.steps.ts @@ -0,0 +1,23 @@ +import { Then, When } from '@cucumber/cucumber' +import { expect } from '@playwright/test' +import type { DifyWorld } from '../../support/world' + +When('I open the apps console', async function (this: DifyWorld) { + await this.getPage().goto('/apps') +}) + +Then('I should stay on the apps console', async function (this: DifyWorld) { + await expect(this.getPage()).toHaveURL(/\/apps(?:\?.*)?$/) +}) + +Then('I should see the {string} button', async function (this: DifyWorld, label: string) { + await expect(this.getPage().getByRole('button', { name: label })).toBeVisible() +}) + +Then('I should not see the {string} button', async function (this: DifyWorld, label: string) { + await expect(this.getPage().getByRole('button', { name: label })).not.toBeVisible() +}) + +Then('I should see the {string} text', async function (this: DifyWorld, text: string) { + await expect(this.getPage().getByText(text)).toBeVisible({ timeout: 30_000 }) +}) diff --git a/e2e/features/step-definitions/smoke/install.steps.ts b/e2e/features/step-definitions/smoke/install.steps.ts new file mode 100644 index 0000000000..857e01a971 --- /dev/null +++ b/e2e/features/step-definitions/smoke/install.steps.ts @@ -0,0 +1,12 @@ +import { Given } from '@cucumber/cucumber' +import { expect } 
from '@playwright/test' +import type { DifyWorld } from '../../support/world' + +Given( + 'the last authentication bootstrap came from a fresh install', + async function (this: DifyWorld) { + const session = await this.getAuthSession() + + expect(session.mode).toBe('install') + }, +) diff --git a/e2e/features/support/hooks.ts b/e2e/features/support/hooks.ts new file mode 100644 index 0000000000..a6862d79f5 --- /dev/null +++ b/e2e/features/support/hooks.ts @@ -0,0 +1,90 @@ +import { After, AfterAll, Before, BeforeAll, Status, setDefaultTimeout } from '@cucumber/cucumber' +import { chromium, type Browser } from '@playwright/test' +import { mkdir, writeFile } from 'node:fs/promises' +import path from 'node:path' +import { fileURLToPath } from 'node:url' +import { ensureAuthenticatedState } from '../../fixtures/auth' +import { baseURL, cucumberHeadless, cucumberSlowMo } from '../../test-env' +import type { DifyWorld } from './world' + +const e2eRoot = fileURLToPath(new URL('../..', import.meta.url)) +const artifactsDir = path.join(e2eRoot, 'cucumber-report', 'artifacts') + +let browser: Browser | undefined + +setDefaultTimeout(60_000) + +const sanitizeForPath = (value: string) => + value.replaceAll(/[^a-zA-Z0-9_-]+/g, '-').replaceAll(/^-+|-+$/g, '') + +const writeArtifact = async ( + scenarioName: string, + extension: 'html' | 'png', + contents: Buffer | string, +) => { + const artifactPath = path.join( + artifactsDir, + `${Date.now()}-${sanitizeForPath(scenarioName || 'scenario')}.${extension}`, + ) + await writeFile(artifactPath, contents) + + return artifactPath +} + +BeforeAll(async () => { + await mkdir(artifactsDir, { recursive: true }) + + browser = await chromium.launch({ + headless: cucumberHeadless, + slowMo: cucumberSlowMo, + }) + + console.log(`[e2e] session cache bootstrap against ${baseURL}`) + await ensureAuthenticatedState(browser, baseURL) +}) + +Before(async function (this: DifyWorld, { pickle }) { + if (!browser) throw new Error('Shared Playwright 
browser is not available.') + + await this.startAuthenticatedSession(browser) + this.scenarioStartedAt = Date.now() + + const tags = pickle.tags.map((tag) => tag.name).join(' ') + console.log(`[e2e] start ${pickle.name}${tags ? ` ${tags}` : ''}`) +}) + +After(async function (this: DifyWorld, { pickle, result }) { + const elapsedMs = this.scenarioStartedAt ? Date.now() - this.scenarioStartedAt : undefined + + if (result?.status !== Status.PASSED && this.page) { + const screenshot = await this.page.screenshot({ + fullPage: true, + }) + const screenshotPath = await writeArtifact(pickle.name, 'png', screenshot) + this.attach(screenshot, 'image/png') + + const html = await this.page.content() + const htmlPath = await writeArtifact(pickle.name, 'html', html) + this.attach(html, 'text/html') + + if (this.consoleErrors.length > 0) + this.attach(`Console Errors:\n${this.consoleErrors.join('\n')}`, 'text/plain') + + if (this.pageErrors.length > 0) + this.attach(`Page Errors:\n${this.pageErrors.join('\n')}`, 'text/plain') + + this.attach(`Artifacts:\n${[screenshotPath, htmlPath].join('\n')}`, 'text/plain') + } + + const status = result?.status || 'UNKNOWN' + console.log( + `[e2e] end ${pickle.name} status=${status}${elapsedMs ? 
` durationMs=${elapsedMs}` : ''}`, + ) + + await this.closeSession() +}) + +AfterAll(async () => { + await browser?.close() + browser = undefined +}) diff --git a/e2e/features/support/world.ts b/e2e/features/support/world.ts new file mode 100644 index 0000000000..15ab8daf16 --- /dev/null +++ b/e2e/features/support/world.ts @@ -0,0 +1,68 @@ +import { type IWorldOptions, World, setWorldConstructor } from '@cucumber/cucumber' +import type { Browser, BrowserContext, ConsoleMessage, Page } from '@playwright/test' +import { + authStatePath, + readAuthSessionMetadata, + type AuthSessionMetadata, +} from '../../fixtures/auth' +import { baseURL, defaultLocale } from '../../test-env' + +export class DifyWorld extends World { + context: BrowserContext | undefined + page: Page | undefined + consoleErrors: string[] = [] + pageErrors: string[] = [] + scenarioStartedAt: number | undefined + session: AuthSessionMetadata | undefined + + constructor(options: IWorldOptions) { + super(options) + this.resetScenarioState() + } + + resetScenarioState() { + this.consoleErrors = [] + this.pageErrors = [] + } + + async startAuthenticatedSession(browser: Browser) { + this.resetScenarioState() + this.context = await browser.newContext({ + baseURL, + locale: defaultLocale, + storageState: authStatePath, + }) + this.context.setDefaultTimeout(30_000) + this.page = await this.context.newPage() + this.page.setDefaultTimeout(30_000) + + this.page.on('console', (message: ConsoleMessage) => { + if (message.type() === 'error') this.consoleErrors.push(message.text()) + }) + this.page.on('pageerror', (error) => { + this.pageErrors.push(error.message) + }) + } + + getPage() { + if (!this.page) throw new Error('Playwright page has not been initialized for this scenario.') + + return this.page + } + + async getAuthSession() { + this.session ??= await readAuthSessionMetadata() + return this.session + } + + async closeSession() { + await this.context?.close() + this.context = undefined + this.page = 
undefined + this.session = undefined + this.scenarioStartedAt = undefined + this.resetScenarioState() + } +} + +setWorldConstructor(DifyWorld) diff --git a/e2e/fixtures/auth.ts b/e2e/fixtures/auth.ts new file mode 100644 index 0000000000..853bfff5ed --- /dev/null +++ b/e2e/fixtures/auth.ts @@ -0,0 +1,148 @@ +import type { Browser, Page } from '@playwright/test' +import { expect } from '@playwright/test' +import { mkdir, readFile, writeFile } from 'node:fs/promises' +import path from 'node:path' +import { fileURLToPath } from 'node:url' +import { defaultBaseURL, defaultLocale } from '../test-env' + +export type AuthSessionMetadata = { + adminEmail: string + baseURL: string + mode: 'install' | 'login' + usedInitPassword: boolean +} + +const WAIT_TIMEOUT_MS = 120_000 +const e2eRoot = fileURLToPath(new URL('..', import.meta.url)) + +export const authDir = path.join(e2eRoot, '.auth') +export const authStatePath = path.join(authDir, 'admin.json') +export const authMetadataPath = path.join(authDir, 'session.json') + +export const adminCredentials = { + email: process.env.E2E_ADMIN_EMAIL || 'e2e-admin@example.com', + name: process.env.E2E_ADMIN_NAME || 'E2E Admin', + password: process.env.E2E_ADMIN_PASSWORD || 'E2eAdmin12345', +} + +const initPassword = process.env.E2E_INIT_PASSWORD || 'E2eInit12345' + +export const resolveBaseURL = (configuredBaseURL?: string) => + configuredBaseURL || process.env.E2E_BASE_URL || defaultBaseURL + +export const readAuthSessionMetadata = async () => { + const content = await readFile(authMetadataPath, 'utf8') + return JSON.parse(content) as AuthSessionMetadata +} + +const escapeRegex = (value: string) => value.replaceAll(/[.*+?^${}()|[\]\\]/g, '\\$&') + +const appURL = (baseURL: string, pathname: string) => new URL(pathname, baseURL).toString() + +const waitForPageState = async (page: Page) => { + const installHeading = page.getByRole('heading', { name: 'Setting up an admin account' }) + const signInButton = page.getByRole('button', { name: 
'Sign in' }) + const initPasswordField = page.getByLabel('Admin initialization password') + + const deadline = Date.now() + WAIT_TIMEOUT_MS + + while (Date.now() < deadline) { + if (await installHeading.isVisible().catch(() => false)) return 'install' as const + if (await signInButton.isVisible().catch(() => false)) return 'login' as const + if (await initPasswordField.isVisible().catch(() => false)) return 'init' as const + + await page.waitForTimeout(1_000) + } + + throw new Error(`Unable to determine auth page state for ${page.url()}`) +} + +const completeInitPasswordIfNeeded = async (page: Page) => { + const initPasswordField = page.getByLabel('Admin initialization password') + if (!(await initPasswordField.isVisible({ timeout: 3_000 }).catch(() => false))) return false + + await initPasswordField.fill(initPassword) + await page.getByRole('button', { name: 'Validate' }).click() + await expect(page.getByRole('heading', { name: 'Setting up an admin account' })).toBeVisible({ + timeout: WAIT_TIMEOUT_MS, + }) + + return true +} + +const completeInstall = async (page: Page, baseURL: string) => { + await expect(page.getByRole('heading', { name: 'Setting up an admin account' })).toBeVisible({ + timeout: WAIT_TIMEOUT_MS, + }) + + await page.getByLabel('Email address').fill(adminCredentials.email) + await page.getByLabel('Username').fill(adminCredentials.name) + await page.getByLabel('Password').fill(adminCredentials.password) + await page.getByRole('button', { name: 'Set up' }).click() + + await expect(page).toHaveURL(new RegExp(`^${escapeRegex(baseURL)}/apps(?:\\?.*)?$`), { + timeout: WAIT_TIMEOUT_MS, + }) +} + +const completeLogin = async (page: Page, baseURL: string) => { + await expect(page.getByRole('button', { name: 'Sign in' })).toBeVisible({ + timeout: WAIT_TIMEOUT_MS, + }) + + await page.getByLabel('Email address').fill(adminCredentials.email) + await page.getByLabel('Password').fill(adminCredentials.password) + await page.getByRole('button', { name: 'Sign in' 
}).click() + + await expect(page).toHaveURL(new RegExp(`^${escapeRegex(baseURL)}/apps(?:\\?.*)?$`), { + timeout: WAIT_TIMEOUT_MS, + }) +} + +export const ensureAuthenticatedState = async (browser: Browser, configuredBaseURL?: string) => { + const baseURL = resolveBaseURL(configuredBaseURL) + + await mkdir(authDir, { recursive: true }) + + const context = await browser.newContext({ + baseURL, + locale: defaultLocale, + }) + const page = await context.newPage() + + try { + await page.goto(appURL(baseURL, '/install'), { waitUntil: 'networkidle' }) + + let usedInitPassword = await completeInitPasswordIfNeeded(page) + let pageState = await waitForPageState(page) + + while (pageState === 'init') { + const completedInitPassword = await completeInitPasswordIfNeeded(page) + if (!completedInitPassword) + throw new Error(`Unable to validate initialization password for ${page.url()}`) + + usedInitPassword = true + pageState = await waitForPageState(page) + } + + if (pageState === 'install') await completeInstall(page, baseURL) + else await completeLogin(page, baseURL) + + await expect(page.getByRole('button', { name: 'Create from Blank' })).toBeVisible({ + timeout: WAIT_TIMEOUT_MS, + }) + + await context.storageState({ path: authStatePath }) + + const metadata: AuthSessionMetadata = { + adminEmail: adminCredentials.email, + baseURL, + mode: pageState, + usedInitPassword, + } + + await writeFile(authMetadataPath, `${JSON.stringify(metadata, null, 2)}\n`, 'utf8') + } finally { + await context.close() + } +} diff --git a/e2e/package.json b/e2e/package.json new file mode 100644 index 0000000000..9b8a1f873f --- /dev/null +++ b/e2e/package.json @@ -0,0 +1,34 @@ +{ + "name": "dify-e2e", + "private": true, + "type": "module", + "scripts": { + "check": "vp check --fix", + "e2e": "tsx ./scripts/run-cucumber.ts", + "e2e:full": "tsx ./scripts/run-cucumber.ts --full", + "e2e:full:headed": "tsx ./scripts/run-cucumber.ts --full --headed", + "e2e:headed": "tsx ./scripts/run-cucumber.ts 
--headed", + "e2e:install": "playwright install --with-deps chromium", + "e2e:middleware:down": "tsx ./scripts/setup.ts middleware-down", + "e2e:middleware:up": "tsx ./scripts/setup.ts middleware-up", + "e2e:reset": "tsx ./scripts/setup.ts reset" + }, + "devDependencies": { + "@cucumber/cucumber": "12.7.0", + "@playwright/test": "1.51.1", + "@types/node": "25.5.0", + "tsx": "4.21.0", + "typescript": "5.9.3", + "vite-plus": "latest" + }, + "engines": { + "node": "^22.22.1" + }, + "packageManager": "pnpm@10.32.1", + "pnpm": { + "overrides": { + "vite": "npm:@voidzero-dev/vite-plus-core@latest", + "vitest": "npm:@voidzero-dev/vite-plus-test@latest" + } + } +} diff --git a/e2e/pnpm-lock.yaml b/e2e/pnpm-lock.yaml new file mode 100644 index 0000000000..b63458ad4a --- /dev/null +++ b/e2e/pnpm-lock.yaml @@ -0,0 +1,2632 @@ +lockfileVersion: '9.0' + +settings: + autoInstallPeers: true + excludeLinksFromLockfile: false + +overrides: + vite: npm:@voidzero-dev/vite-plus-core@latest + vitest: npm:@voidzero-dev/vite-plus-test@latest + +importers: + + .: + devDependencies: + '@cucumber/cucumber': + specifier: 12.7.0 + version: 12.7.0 + '@playwright/test': + specifier: 1.51.1 + version: 1.51.1 + '@types/node': + specifier: 25.5.0 + version: 25.5.0 + tsx: + specifier: 4.21.0 + version: 4.21.0 + typescript: + specifier: 5.9.3 + version: 5.9.3 + vite-plus: + specifier: latest + version: 0.1.14(@types/node@25.5.0)(esbuild@0.27.4)(tsx@4.21.0)(typescript@5.9.3)(vite@8.0.3(@types/node@25.5.0)(esbuild@0.27.4)(tsx@4.21.0)(yaml@2.8.3))(yaml@2.8.3) + +packages: + + '@babel/code-frame@7.29.0': + resolution: {integrity: sha512-9NhCeYjq9+3uxgdtp20LSiJXJvN0FeCtNGpJxuMFZ1Kv3cWUNb6DOhJwUvcVCzKGR66cw4njwM6hrJLqgOwbcw==} + engines: {node: '>=6.9.0'} + + '@babel/helper-validator-identifier@7.28.5': + resolution: {integrity: sha512-qSs4ifwzKJSV39ucNjsvc6WVHs6b7S03sOh2OcHF9UHfVPqWWALUsNUVzhSBiItjRZoLHx7nIarVjqKVusUZ1Q==} + engines: {node: '>=6.9.0'} + + '@colors/colors@1.5.0': + resolution: {integrity: 
sha512-ooWCrlZP11i8GImSjTHYHLkvFDP48nS4+204nGb1RiX/WXYHmJA2III9/e2DWVabCESdW7hBAEzHRqUn9OUVvQ==} + engines: {node: '>=0.1.90'} + + '@cucumber/ci-environment@13.0.0': + resolution: {integrity: sha512-cs+3NzfNkGbcmHPddjEv4TKFiBpZRQ6WJEEufB9mw+ExS22V/4R/zpDSEG+fsJ/iSNCd6A2sATdY8PFOyY3YnA==} + + '@cucumber/cucumber-expressions@19.0.0': + resolution: {integrity: sha512-4FKoOQh2Uf6F6/Ln+1OxuK8LkTg6PyAqekhf2Ix8zqV2M54sH+m7XNJNLhOFOAW/t9nxzRbw2CcvXbCLjcvHZg==} + + '@cucumber/cucumber@12.7.0': + resolution: {integrity: sha512-7A/9CJpJDxv1SQ7hAZU0zPn2yRxx6XMR+LO4T94Enm3cYNWsEEj+RGX38NLX4INT+H6w5raX3Csb/qs4vUBsOA==} + engines: {node: 20 || 22 || >=24} + hasBin: true + + '@cucumber/gherkin-streams@6.0.0': + resolution: {integrity: sha512-HLSHMmdDH0vCr7vsVEURcDA4WwnRLdjkhqr6a4HQ3i4RFK1wiDGPjBGVdGJLyuXuRdJpJbFc6QxHvT8pU4t6jw==} + hasBin: true + peerDependencies: + '@cucumber/gherkin': '>=22.0.0' + '@cucumber/message-streams': '>=4.0.0' + '@cucumber/messages': '>=17.1.1' + + '@cucumber/gherkin-utils@11.0.0': + resolution: {integrity: sha512-LJ+s4+TepHTgdKWDR4zbPyT7rQjmYIcukTwNbwNwgqr6i8Gjcmzf6NmtbYDA19m1ZFg6kWbFsmHnj37ZuX+kZA==} + hasBin: true + + '@cucumber/gherkin@38.0.0': + resolution: {integrity: sha512-duEXK+KDfQUzu3vsSzXjkxQ2tirF5PRsc1Xrts6THKHJO6mjw4RjM8RV+vliuDasmhhrmdLcOcM7d9nurNTJKw==} + + '@cucumber/html-formatter@23.0.0': + resolution: {integrity: sha512-WwcRzdM8Ixy4e53j+Frm3fKM5rNuIyWUfy4HajEN+Xk/YcjA6yW0ACGTFDReB++VDZz/iUtwYdTlPRY36NbqJg==} + peerDependencies: + '@cucumber/messages': '>=18' + + '@cucumber/junit-xml-formatter@0.9.0': + resolution: {integrity: sha512-WF+A7pBaXpKMD1i7K59Nk5519zj4extxY4+4nSgv5XLsGXHDf1gJnb84BkLUzevNtp2o2QzMG0vWLwSm8V5blw==} + peerDependencies: + '@cucumber/messages': '*' + + '@cucumber/message-streams@4.0.1': + resolution: {integrity: sha512-Kxap9uP5jD8tHUZVjTWgzxemi/0uOsbGjd4LBOSxcJoOCRbESFwemUzilJuzNTB8pcTQUh8D5oudUyxfkJOKmA==} + peerDependencies: + '@cucumber/messages': '>=17.1.1' + + '@cucumber/messages@32.0.1': + resolution: 
{integrity: sha512-1OSoW+GQvFUNAl6tdP2CTBexTXMNJF0094goVUcvugtQeXtJ0K8sCP0xbq7GGoiezs/eJAAOD03+zAPT64orHQ==} + + '@cucumber/pretty-formatter@1.0.1': + resolution: {integrity: sha512-A1lU4VVP0aUWdOTmpdzvXOyEYuPtBDI0xYwYJnmoMDplzxMdhcHk86lyyvYDoMoPzzq6OkOE3isuosvUU4X7IQ==} + peerDependencies: + '@cucumber/cucumber': '>=7.0.0' + '@cucumber/messages': '*' + + '@cucumber/query@14.7.0': + resolution: {integrity: sha512-fiqZ4gMEgYjmbuWproF/YeCdD5y+gD2BqgBIGbpihOsx6UlNsyzoDSfO+Tny0q65DxfK+pHo2UkPyEl7dO7wmQ==} + peerDependencies: + '@cucumber/messages': '*' + + '@cucumber/tag-expressions@9.1.0': + resolution: {integrity: sha512-bvHjcRFZ+J1TqIa9eFNO1wGHqwx4V9ZKV3hYgkuK/VahHx73uiP4rKV3JVrvWSMrwrFvJG6C8aEwnCWSvbyFdQ==} + + '@emnapi/core@1.9.1': + resolution: {integrity: sha512-mukuNALVsoix/w1BJwFzwXBN/dHeejQtuVzcDsfOEsdpCumXb/E9j8w11h5S54tT1xhifGfbbSm/ICrObRb3KA==} + + '@emnapi/runtime@1.9.1': + resolution: {integrity: sha512-VYi5+ZVLhpgK4hQ0TAjiQiZ6ol0oe4mBx7mVv7IflsiEp0OWoVsp/+f9Vc1hOhE0TtkORVrI1GvzyreqpgWtkA==} + + '@emnapi/wasi-threads@1.2.0': + resolution: {integrity: sha512-N10dEJNSsUx41Z6pZsXU8FjPjpBEplgH24sfkmITrBED1/U2Esum9F3lfLrMjKHHjmi557zQn7kR9R+XWXu5Rg==} + + '@esbuild/aix-ppc64@0.27.4': + resolution: {integrity: sha512-cQPwL2mp2nSmHHJlCyoXgHGhbEPMrEEU5xhkcy3Hs/O7nGZqEpZ2sUtLaL9MORLtDfRvVl2/3PAuEkYZH0Ty8Q==} + engines: {node: '>=18'} + cpu: [ppc64] + os: [aix] + + '@esbuild/android-arm64@0.27.4': + resolution: {integrity: sha512-gdLscB7v75wRfu7QSm/zg6Rx29VLdy9eTr2t44sfTW7CxwAtQghZ4ZnqHk3/ogz7xao0QAgrkradbBzcqFPasw==} + engines: {node: '>=18'} + cpu: [arm64] + os: [android] + + '@esbuild/android-arm@0.27.4': + resolution: {integrity: sha512-X9bUgvxiC8CHAGKYufLIHGXPJWnr0OCdR0anD2e21vdvgCI8lIfqFbnoeOz7lBjdrAGUhqLZLcQo6MLhTO2DKQ==} + engines: {node: '>=18'} + cpu: [arm] + os: [android] + + '@esbuild/android-x64@0.27.4': + resolution: {integrity: sha512-PzPFnBNVF292sfpfhiyiXCGSn9HZg5BcAz+ivBuSsl6Rk4ga1oEXAamhOXRFyMcjwr2DVtm40G65N3GLeH1Lvw==} + engines: {node: '>=18'} + 
cpu: [x64] + os: [android] + + '@esbuild/darwin-arm64@0.27.4': + resolution: {integrity: sha512-b7xaGIwdJlht8ZFCvMkpDN6uiSmnxxK56N2GDTMYPr2/gzvfdQN8rTfBsvVKmIVY/X7EM+/hJKEIbbHs9oA4tQ==} + engines: {node: '>=18'} + cpu: [arm64] + os: [darwin] + + '@esbuild/darwin-x64@0.27.4': + resolution: {integrity: sha512-sR+OiKLwd15nmCdqpXMnuJ9W2kpy0KigzqScqHI3Hqwr7IXxBp3Yva+yJwoqh7rE8V77tdoheRYataNKL4QrPw==} + engines: {node: '>=18'} + cpu: [x64] + os: [darwin] + + '@esbuild/freebsd-arm64@0.27.4': + resolution: {integrity: sha512-jnfpKe+p79tCnm4GVav68A7tUFeKQwQyLgESwEAUzyxk/TJr4QdGog9sqWNcUbr/bZt/O/HXouspuQDd9JxFSw==} + engines: {node: '>=18'} + cpu: [arm64] + os: [freebsd] + + '@esbuild/freebsd-x64@0.27.4': + resolution: {integrity: sha512-2kb4ceA/CpfUrIcTUl1wrP/9ad9Atrp5J94Lq69w7UwOMolPIGrfLSvAKJp0RTvkPPyn6CIWrNy13kyLikZRZQ==} + engines: {node: '>=18'} + cpu: [x64] + os: [freebsd] + + '@esbuild/linux-arm64@0.27.4': + resolution: {integrity: sha512-7nQOttdzVGth1iz57kxg9uCz57dxQLHWxopL6mYuYthohPKEK0vU0C3O21CcBK6KDlkYVcnDXY099HcCDXd9dA==} + engines: {node: '>=18'} + cpu: [arm64] + os: [linux] + + '@esbuild/linux-arm@0.27.4': + resolution: {integrity: sha512-aBYgcIxX/wd5n2ys0yESGeYMGF+pv6g0DhZr3G1ZG4jMfruU9Tl1i2Z+Wnj9/KjGz1lTLCcorqE2viePZqj4Eg==} + engines: {node: '>=18'} + cpu: [arm] + os: [linux] + + '@esbuild/linux-ia32@0.27.4': + resolution: {integrity: sha512-oPtixtAIzgvzYcKBQM/qZ3R+9TEUd1aNJQu0HhGyqtx6oS7qTpvjheIWBbes4+qu1bNlo2V4cbkISr8q6gRBFA==} + engines: {node: '>=18'} + cpu: [ia32] + os: [linux] + + '@esbuild/linux-loong64@0.27.4': + resolution: {integrity: sha512-8mL/vh8qeCoRcFH2nM8wm5uJP+ZcVYGGayMavi8GmRJjuI3g1v6Z7Ni0JJKAJW+m0EtUuARb6Lmp4hMjzCBWzA==} + engines: {node: '>=18'} + cpu: [loong64] + os: [linux] + + '@esbuild/linux-mips64el@0.27.4': + resolution: {integrity: sha512-1RdrWFFiiLIW7LQq9Q2NES+HiD4NyT8Itj9AUeCl0IVCA459WnPhREKgwrpaIfTOe+/2rdntisegiPWn/r/aAw==} + engines: {node: '>=18'} + cpu: [mips64el] + os: [linux] + + '@esbuild/linux-ppc64@0.27.4': + 
resolution: {integrity: sha512-tLCwNG47l3sd9lpfyx9LAGEGItCUeRCWeAx6x2Jmbav65nAwoPXfewtAdtbtit/pJFLUWOhpv0FpS6GQAmPrHA==} + engines: {node: '>=18'} + cpu: [ppc64] + os: [linux] + + '@esbuild/linux-riscv64@0.27.4': + resolution: {integrity: sha512-BnASypppbUWyqjd1KIpU4AUBiIhVr6YlHx/cnPgqEkNoVOhHg+YiSVxM1RLfiy4t9cAulbRGTNCKOcqHrEQLIw==} + engines: {node: '>=18'} + cpu: [riscv64] + os: [linux] + + '@esbuild/linux-s390x@0.27.4': + resolution: {integrity: sha512-+eUqgb/Z7vxVLezG8bVB9SfBie89gMueS+I0xYh2tJdw3vqA/0ImZJ2ROeWwVJN59ihBeZ7Tu92dF/5dy5FttA==} + engines: {node: '>=18'} + cpu: [s390x] + os: [linux] + + '@esbuild/linux-x64@0.27.4': + resolution: {integrity: sha512-S5qOXrKV8BQEzJPVxAwnryi2+Iq5pB40gTEIT69BQONqR7JH1EPIcQ/Uiv9mCnn05jff9umq/5nqzxlqTOg9NA==} + engines: {node: '>=18'} + cpu: [x64] + os: [linux] + + '@esbuild/netbsd-arm64@0.27.4': + resolution: {integrity: sha512-xHT8X4sb0GS8qTqiwzHqpY00C95DPAq7nAwX35Ie/s+LO9830hrMd3oX0ZMKLvy7vsonee73x0lmcdOVXFzd6Q==} + engines: {node: '>=18'} + cpu: [arm64] + os: [netbsd] + + '@esbuild/netbsd-x64@0.27.4': + resolution: {integrity: sha512-RugOvOdXfdyi5Tyv40kgQnI0byv66BFgAqjdgtAKqHoZTbTF2QqfQrFwa7cHEORJf6X2ht+l9ABLMP0dnKYsgg==} + engines: {node: '>=18'} + cpu: [x64] + os: [netbsd] + + '@esbuild/openbsd-arm64@0.27.4': + resolution: {integrity: sha512-2MyL3IAaTX+1/qP0O1SwskwcwCoOI4kV2IBX1xYnDDqthmq5ArrW94qSIKCAuRraMgPOmG0RDTA74mzYNQA9ow==} + engines: {node: '>=18'} + cpu: [arm64] + os: [openbsd] + + '@esbuild/openbsd-x64@0.27.4': + resolution: {integrity: sha512-u8fg/jQ5aQDfsnIV6+KwLOf1CmJnfu1ShpwqdwC0uA7ZPwFws55Ngc12vBdeUdnuWoQYx/SOQLGDcdlfXhYmXQ==} + engines: {node: '>=18'} + cpu: [x64] + os: [openbsd] + + '@esbuild/openharmony-arm64@0.27.4': + resolution: {integrity: sha512-JkTZrl6VbyO8lDQO3yv26nNr2RM2yZzNrNHEsj9bm6dOwwu9OYN28CjzZkH57bh4w0I2F7IodpQvUAEd1mbWXg==} + engines: {node: '>=18'} + cpu: [arm64] + os: [openharmony] + + '@esbuild/sunos-x64@0.27.4': + resolution: {integrity: 
sha512-/gOzgaewZJfeJTlsWhvUEmUG4tWEY2Spp5M20INYRg2ZKl9QPO3QEEgPeRtLjEWSW8FilRNacPOg8R1uaYkA6g==} + engines: {node: '>=18'} + cpu: [x64] + os: [sunos] + + '@esbuild/win32-arm64@0.27.4': + resolution: {integrity: sha512-Z9SExBg2y32smoDQdf1HRwHRt6vAHLXcxD2uGgO/v2jK7Y718Ix4ndsbNMU/+1Qiem9OiOdaqitioZwxivhXYg==} + engines: {node: '>=18'} + cpu: [arm64] + os: [win32] + + '@esbuild/win32-ia32@0.27.4': + resolution: {integrity: sha512-DAyGLS0Jz5G5iixEbMHi5KdiApqHBWMGzTtMiJ72ZOLhbu/bzxgAe8Ue8CTS3n3HbIUHQz/L51yMdGMeoxXNJw==} + engines: {node: '>=18'} + cpu: [ia32] + os: [win32] + + '@esbuild/win32-x64@0.27.4': + resolution: {integrity: sha512-+knoa0BDoeXgkNvvV1vvbZX4+hizelrkwmGJBdT17t8FNPwG2lKemmuMZlmaNQ3ws3DKKCxpb4zRZEIp3UxFCg==} + engines: {node: '>=18'} + cpu: [x64] + os: [win32] + + '@napi-rs/wasm-runtime@1.1.1': + resolution: {integrity: sha512-p64ah1M1ld8xjWv3qbvFwHiFVWrq1yFvV4f7w+mzaqiR4IlSgkqhcRdHwsGgomwzBH51sRY4NEowLxnaBjcW/A==} + + '@oxc-project/runtime@0.121.0': + resolution: {integrity: sha512-p0bQukD8OEHxzY4T9OlANBbEFGnOnjo1CYi50HES7OD36UO2yPh6T+uOJKLtlg06eclxroipRCpQGMpeH8EJ/g==} + engines: {node: ^20.19.0 || >=22.12.0} + + '@oxc-project/types@0.122.0': + resolution: {integrity: sha512-oLAl5kBpV4w69UtFZ9xqcmTi+GENWOcPF7FCrczTiBbmC0ibXxCwyvZGbO39rCVEuLGAZM84DH0pUIyyv/YJzA==} + + '@oxfmt/binding-android-arm-eabi@0.42.0': + resolution: {integrity: sha512-dsqPTYsozeokRjlrt/b4E7Pj0z3eS3Eg74TWQuuKbjY4VttBmA88rB7d50Xrd+TZ986qdXCNeZRPEzZHAe+jow==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [arm] + os: [android] + + '@oxfmt/binding-android-arm64@0.42.0': + resolution: {integrity: sha512-t+aAjHxcr5eOBphFHdg1ouQU9qmZZoRxnX7UOJSaTwSoKsb6TYezNKO0YbWytGXCECObRqNcUxPoPr0KaraAIg==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [arm64] + os: [android] + + '@oxfmt/binding-darwin-arm64@0.42.0': + resolution: {integrity: sha512-ulpSEYMKg61C5bRMZinFHrKJYRoKGVbvMEXA5zM1puX3O9T6Q4XXDbft20yrDijpYWeuG59z3Nabt+npeTsM1A==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [arm64] + 
os: [darwin] + + '@oxfmt/binding-darwin-x64@0.42.0': + resolution: {integrity: sha512-ttxLKhQYPdFiM8I/Ri37cvqChE4Xa562nNOsZFcv1CKTVLeEozXjKuYClNvxkXmNlcF55nzM80P+CQkdFBu+uQ==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [x64] + os: [darwin] + + '@oxfmt/binding-freebsd-x64@0.42.0': + resolution: {integrity: sha512-Og7QS3yI3tdIKYZ58SXik0rADxIk2jmd+/YvuHRyKULWpG4V2fR5V4hvKm624Mc0cQET35waPXiCQWvjQEjwYQ==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [x64] + os: [freebsd] + + '@oxfmt/binding-linux-arm-gnueabihf@0.42.0': + resolution: {integrity: sha512-jwLOw/3CW4H6Vxcry4/buQHk7zm9Ne2YsidzTL1kpiMe4qqrRCwev3dkyWe2YkFmP+iZCQ7zku4KwjcLRoh8ew==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [arm] + os: [linux] + + '@oxfmt/binding-linux-arm-musleabihf@0.42.0': + resolution: {integrity: sha512-XwXu2vkMtiq2h7tfvN+WA/9/5/1IoGAVCFPiiQUvcAuG3efR97KNcRGM8BetmbYouFotQ2bDal3yyjUx6IPsTg==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [arm] + os: [linux] + + '@oxfmt/binding-linux-arm64-gnu@0.42.0': + resolution: {integrity: sha512-ea7s/XUJoT7ENAtUQDudFe3nkSM3e3Qpz4nJFRdzO2wbgXEcjnchKLEsV3+t4ev3r8nWxIYr9NRjPWtnyIFJVA==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [arm64] + os: [linux] + libc: [glibc] + + '@oxfmt/binding-linux-arm64-musl@0.42.0': + resolution: {integrity: sha512-+JA0YMlSdDqmacygGi2REp57c3fN+tzARD8nwsukx9pkCHK+6DkbAA9ojS4lNKsiBjIW8WWa0pBrBWhdZEqfuw==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [arm64] + os: [linux] + libc: [musl] + + '@oxfmt/binding-linux-ppc64-gnu@0.42.0': + resolution: {integrity: sha512-VfnET0j4Y5mdfCzh5gBt0NK28lgn5DKx+8WgSMLYYeSooHhohdbzwAStLki9pNuGy51y4I7IoW8bqwAaCMiJQg==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [ppc64] + os: [linux] + libc: [glibc] + + '@oxfmt/binding-linux-riscv64-gnu@0.42.0': + resolution: {integrity: sha512-gVlCbmBkB0fxBWbhBj9rcxezPydsQHf4MFKeHoTSPicOQ+8oGeTQgQ8EeesSybWeiFPVRx3bgdt4IJnH6nOjAA==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [riscv64] + os: [linux] + libc: [glibc] + + 
'@oxfmt/binding-linux-riscv64-musl@0.42.0': + resolution: {integrity: sha512-zN5OfstL0avgt/IgvRu0zjQzVh/EPkcLzs33E9LMAzpqlLWiPWeMDZyMGFlSRGOdDjuNmlZBCgj0pFnK5u32TQ==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [riscv64] + os: [linux] + libc: [musl] + + '@oxfmt/binding-linux-s390x-gnu@0.42.0': + resolution: {integrity: sha512-9X6+H2L0qMc2sCAgO9HS03bkGLMKvOFjmEdchaFlany3vNZOjnVui//D8k/xZAtQv2vaCs1reD5KAgPoIU4msA==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [s390x] + os: [linux] + libc: [glibc] + + '@oxfmt/binding-linux-x64-gnu@0.42.0': + resolution: {integrity: sha512-BajxJ6KQvMMdpXGPWhBGyjb2Jvx4uec0w+wi6TJZ6Tv7+MzPwe0pO8g5h1U0jyFgoaF7mDl6yKPW3ykWcbUJRw==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [x64] + os: [linux] + libc: [glibc] + + '@oxfmt/binding-linux-x64-musl@0.42.0': + resolution: {integrity: sha512-0wV284I6vc5f0AqAhgAbHU2935B4bVpncPoe5n/WzVZY/KnHgqxC8iSFGeSyLWEgstFboIcWkOPck7tqbdHkzA==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [x64] + os: [linux] + libc: [musl] + + '@oxfmt/binding-openharmony-arm64@0.42.0': + resolution: {integrity: sha512-p4BG6HpGnhfgHk1rzZfyR6zcWkE7iLrWxyehHfXUy4Qa5j3e0roglFOdP/Nj5cJJ58MA3isQ5dlfkW2nNEpolw==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [arm64] + os: [openharmony] + + '@oxfmt/binding-win32-arm64-msvc@0.42.0': + resolution: {integrity: sha512-mn//WV60A+IetORDxYieYGAoQso4KnVRRjORDewMcod4irlRe0OSC7YPhhwaexYNPQz/GCFk+v9iUcZ2W22yxQ==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [arm64] + os: [win32] + + '@oxfmt/binding-win32-ia32-msvc@0.42.0': + resolution: {integrity: sha512-3gWltUrvuz4LPJXWivoAxZ28Of2O4N7OGuM5/X3ubPXCEV8hmgECLZzjz7UYvSDUS3grfdccQwmjynm+51EFpw==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [ia32] + os: [win32] + + '@oxfmt/binding-win32-x64-msvc@0.42.0': + resolution: {integrity: sha512-Wg4TMAfQRL9J9AZevJ/ZNy3uyyDztDYQtGr4P8UyyzIhLhFrdSmz1J/9JT+rv0fiCDLaFOBQnj3f3K3+a5PzDQ==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [x64] + os: [win32] + + 
'@oxlint-tsgolint/darwin-arm64@0.17.3': + resolution: {integrity: sha512-5aDl4mxXWs+Bj02pNrX6YY6v9KMZjLIytXoqolLEo0dfBNVeZUonZgJAa/w0aUmijwIRrBhxEzb42oLuUtfkGw==} + cpu: [arm64] + os: [darwin] + + '@oxlint-tsgolint/darwin-x64@0.17.3': + resolution: {integrity: sha512-gPBy4DS5ueCgXzko20XsNZzDe/Cxde056B+QuPLGvz05CGEAtmRfpImwnyY2lAXXjPL+SmnC/OYexu8zI12yHQ==} + cpu: [x64] + os: [darwin] + + '@oxlint-tsgolint/linux-arm64@0.17.3': + resolution: {integrity: sha512-+pkunvCfB6pB0G9qHVVXUao3nqzXQPo4O3DReIi+5nGa+bOU3J3Srgy+Zb8VyOL+WDsSMJ+U7+r09cKHWhz3hg==} + cpu: [arm64] + os: [linux] + + '@oxlint-tsgolint/linux-x64@0.17.3': + resolution: {integrity: sha512-/kW5oXtBThu4FjmgIBthdmMjWLzT3M1TEDQhxDu7hQU5xDeTd60CDXb2SSwKCbue9xu7MbiFoJu83LN0Z/d38g==} + cpu: [x64] + os: [linux] + + '@oxlint-tsgolint/win32-arm64@0.17.3': + resolution: {integrity: sha512-NMELRvbz4Ed4dxg8WiqZxtu3k4OJEp2B9KInZW+BMfqEqbwZdEJY83tbqz2hD1EjKO2akrqBQ0GpRUJEkd8kKw==} + cpu: [arm64] + os: [win32] + + '@oxlint-tsgolint/win32-x64@0.17.3': + resolution: {integrity: sha512-+pJ7r8J3SLPws5uoidVplZc8R/lpKyKPE6LoPGv9BME00Y1VjT6jWGx/dtUN8PWvcu3iTC6k+8u3ojFSJNmWTg==} + cpu: [x64] + os: [win32] + + '@oxlint/binding-android-arm-eabi@1.57.0': + resolution: {integrity: sha512-C7EiyfAJG4B70496eV543nKiq5cH0o/xIh/ufbjQz3SIvHhlDDsyn+mRFh+aW8KskTyUpyH2LGWL8p2oN6bl1A==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [arm] + os: [android] + + '@oxlint/binding-android-arm64@1.57.0': + resolution: {integrity: sha512-9i80AresjZ/FZf5xK8tKFbhQnijD4s1eOZw6/FHUwD59HEZbVLRc2C88ADYJfLZrF5XofWDiRX/Ja9KefCLy7w==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [arm64] + os: [android] + + '@oxlint/binding-darwin-arm64@1.57.0': + resolution: {integrity: sha512-0eUfhRz5L2yKa9I8k3qpyl37XK3oBS5BvrgdVIx599WZK63P8sMbg+0s4IuxmIiZuBK68Ek+Z+gcKgeYf0otsg==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [arm64] + os: [darwin] + + '@oxlint/binding-darwin-x64@1.57.0': + resolution: {integrity: 
sha512-UvrSuzBaYOue+QMAcuDITe0k/Vhj6KZGjfnI6x+NkxBTke/VoM7ZisaxgNY0LWuBkTnd1OmeQfEQdQ48fRjkQg==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [x64] + os: [darwin] + + '@oxlint/binding-freebsd-x64@1.57.0': + resolution: {integrity: sha512-wtQq0dCoiw4bUwlsNVDJJ3pxJA218fOezpgtLKrbQqUtQJcM9yP8z+I9fu14aHg0uyAxIY+99toL6uBa2r7nxA==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [x64] + os: [freebsd] + + '@oxlint/binding-linux-arm-gnueabihf@1.57.0': + resolution: {integrity: sha512-qxFWl2BBBFcT4djKa+OtMdnLgoHEJXpqjyGwz8OhW35ImoCwR5qtAGqApNYce5260FQqoAHW8S8eZTjiX67Tsg==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [arm] + os: [linux] + + '@oxlint/binding-linux-arm-musleabihf@1.57.0': + resolution: {integrity: sha512-SQoIsBU7J0bDW15/f0/RvxHfY3Y0+eB/caKBQtNFbuerTiA6JCYx9P1MrrFTwY2dTm/lMgTSgskvCEYk2AtG/Q==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [arm] + os: [linux] + + '@oxlint/binding-linux-arm64-gnu@1.57.0': + resolution: {integrity: sha512-jqxYd1W6WMeozsCmqe9Rzbu3SRrGTyGDAipRlRggetyYbUksJqJKvUNTQtZR/KFoJPb+grnSm5SHhdWrywv3RQ==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [arm64] + os: [linux] + libc: [glibc] + + '@oxlint/binding-linux-arm64-musl@1.57.0': + resolution: {integrity: sha512-i66WyEPVEvq9bxRUCJ/MP5EBfnTDN3nhwEdFZFTO5MmLLvzngfWEG3NSdXQzTT3vk5B9i6C2XSIYBh+aG6uqyg==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [arm64] + os: [linux] + libc: [musl] + + '@oxlint/binding-linux-ppc64-gnu@1.57.0': + resolution: {integrity: sha512-oMZDCwz4NobclZU3pH+V1/upVlJZiZvne4jQP+zhJwt+lmio4XXr4qG47CehvrW1Lx2YZiIHuxM2D4YpkG3KVA==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [ppc64] + os: [linux] + libc: [glibc] + + '@oxlint/binding-linux-riscv64-gnu@1.57.0': + resolution: {integrity: sha512-uoBnjJ3MMEBbfnWC1jSFr7/nSCkcQYa72NYoNtLl1imshDnWSolYCjzb8LVCwYCCfLJXD+0gBLD7fyC14c0+0g==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [riscv64] + os: [linux] + libc: [glibc] + + '@oxlint/binding-linux-riscv64-musl@1.57.0': + resolution: {integrity: 
sha512-BdrwD7haPZ8a9KrZhKJRSj6jwCor+Z8tHFZ3PT89Y3Jq5v3LfMfEePeAmD0LOTWpiTmzSzdmyw9ijneapiVHKQ==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [riscv64] + os: [linux] + libc: [musl] + + '@oxlint/binding-linux-s390x-gnu@1.57.0': + resolution: {integrity: sha512-BNs+7ZNsRstVg2tpNxAXfMX/Iv5oZh204dVyb8Z37+/gCh+yZqNTlg6YwCLIMPSk5wLWIGOaQjT0GUOahKYImw==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [s390x] + os: [linux] + libc: [glibc] + + '@oxlint/binding-linux-x64-gnu@1.57.0': + resolution: {integrity: sha512-AghS18w+XcENcAX0+BQGLiqjpqpaxKJa4cWWP0OWNLacs27vHBxu7TYkv9LUSGe5w8lOJHeMxcYfZNOAPqw2bg==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [x64] + os: [linux] + libc: [glibc] + + '@oxlint/binding-linux-x64-musl@1.57.0': + resolution: {integrity: sha512-E/FV3GB8phu/Rpkhz5T96hAiJlGzn91qX5yj5gU754P5cmVGXY1Jw/VSjDSlZBCY3VHjsVLdzgdkJaomEmcNOg==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [x64] + os: [linux] + libc: [musl] + + '@oxlint/binding-openharmony-arm64@1.57.0': + resolution: {integrity: sha512-xvZ2yZt0nUVfU14iuGv3V25jpr9pov5N0Wr28RXnHFxHCRxNDMtYPHV61gGLhN9IlXM96gI4pyYpLSJC5ClLCQ==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [arm64] + os: [openharmony] + + '@oxlint/binding-win32-arm64-msvc@1.57.0': + resolution: {integrity: sha512-Z4D8Pd0AyHBKeazhdIXeUUy5sIS3Mo0veOlzlDECg6PhRRKgEsBJCCV1n+keUZtQ04OP+i7+itS3kOykUyNhDg==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [arm64] + os: [win32] + + '@oxlint/binding-win32-ia32-msvc@1.57.0': + resolution: {integrity: sha512-StOZ9nFMVKvevicbQfql6Pouu9pgbeQnu60Fvhz2S6yfMaii+wnueLnqQ5I1JPgNF0Syew4voBlAaHD13wH6tw==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [ia32] + os: [win32] + + '@oxlint/binding-win32-x64-msvc@1.57.0': + resolution: {integrity: sha512-6PuxhYgth8TuW0+ABPOIkGdBYw+qYGxgIdXPHSVpiCDm+hqTTWCmC739St1Xni0DJBt8HnSHTG67i1y6gr8qrA==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [x64] + os: [win32] + + '@playwright/test@1.51.1': + resolution: {integrity: 
sha512-nM+kEaTSAoVlXmMPH10017vn3FSiFqr/bh4fKg9vmAdMfd9SDqRZNvPSiAHADc/itWak+qPvMPZQOPwCBW7k7Q==} + engines: {node: '>=18'} + hasBin: true + + '@polka/url@1.0.0-next.29': + resolution: {integrity: sha512-wwQAWhWSuHaag8c4q/KN/vCoeOJYshAIvMQwD4GpSb3OiZklFfvAgmj0VCBBImRpuF/aFgIRzllXlVX93Jevww==} + + '@rolldown/binding-android-arm64@1.0.0-rc.12': + resolution: {integrity: sha512-pv1y2Fv0JybcykuiiD3qBOBdz6RteYojRFY1d+b95WVuzx211CRh+ytI/+9iVyWQ6koTh5dawe4S/yRfOFjgaA==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [arm64] + os: [android] + + '@rolldown/binding-darwin-arm64@1.0.0-rc.12': + resolution: {integrity: sha512-cFYr6zTG/3PXXF3pUO+umXxt1wkRK/0AYT8lDwuqvRC+LuKYWSAQAQZjCWDQpAH172ZV6ieYrNnFzVVcnSflAg==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [arm64] + os: [darwin] + + '@rolldown/binding-darwin-x64@1.0.0-rc.12': + resolution: {integrity: sha512-ZCsYknnHzeXYps0lGBz8JrF37GpE9bFVefrlmDrAQhOEi4IOIlcoU1+FwHEtyXGx2VkYAvhu7dyBf75EJQffBw==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [x64] + os: [darwin] + + '@rolldown/binding-freebsd-x64@1.0.0-rc.12': + resolution: {integrity: sha512-dMLeprcVsyJsKolRXyoTH3NL6qtsT0Y2xeuEA8WQJquWFXkEC4bcu1rLZZSnZRMtAqwtrF/Ib9Ddtpa/Gkge9Q==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [x64] + os: [freebsd] + + '@rolldown/binding-linux-arm-gnueabihf@1.0.0-rc.12': + resolution: {integrity: sha512-YqWjAgGC/9M1lz3GR1r1rP79nMgo3mQiiA+Hfo+pvKFK1fAJ1bCi0ZQVh8noOqNacuY1qIcfyVfP6HoyBRZ85Q==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [arm] + os: [linux] + + '@rolldown/binding-linux-arm64-gnu@1.0.0-rc.12': + resolution: {integrity: sha512-/I5AS4cIroLpslsmzXfwbe5OmWvSsrFuEw3mwvbQ1kDxJ822hFHIx+vsN/TAzNVyepI/j/GSzrtCIwQPeKCLIg==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [arm64] + os: [linux] + libc: [glibc] + + '@rolldown/binding-linux-arm64-musl@1.0.0-rc.12': + resolution: {integrity: sha512-V6/wZztnBqlx5hJQqNWwFdxIKN0m38p8Jas+VoSfgH54HSj9tKTt1dZvG6JRHcjh6D7TvrJPWFGaY9UBVOaWPw==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: 
[arm64] + os: [linux] + libc: [musl] + + '@rolldown/binding-linux-ppc64-gnu@1.0.0-rc.12': + resolution: {integrity: sha512-AP3E9BpcUYliZCxa3w5Kwj9OtEVDYK6sVoUzy4vTOJsjPOgdaJZKFmN4oOlX0Wp0RPV2ETfmIra9x1xuayFB7g==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [ppc64] + os: [linux] + libc: [glibc] + + '@rolldown/binding-linux-s390x-gnu@1.0.0-rc.12': + resolution: {integrity: sha512-nWwpvUSPkoFmZo0kQazZYOrT7J5DGOJ/+QHHzjvNlooDZED8oH82Yg67HvehPPLAg5fUff7TfWFHQS8IV1n3og==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [s390x] + os: [linux] + libc: [glibc] + + '@rolldown/binding-linux-x64-gnu@1.0.0-rc.12': + resolution: {integrity: sha512-RNrafz5bcwRy+O9e6P8Z/OCAJW/A+qtBczIqVYwTs14pf4iV1/+eKEjdOUta93q2TsT/FI0XYDP3TCky38LMAg==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [x64] + os: [linux] + libc: [glibc] + + '@rolldown/binding-linux-x64-musl@1.0.0-rc.12': + resolution: {integrity: sha512-Jpw/0iwoKWx3LJ2rc1yjFrj+T7iHZn2JDg1Yny1ma0luviFS4mhAIcd1LFNxK3EYu3DHWCps0ydXQ5i/rrJ2ig==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [x64] + os: [linux] + libc: [musl] + + '@rolldown/binding-openharmony-arm64@1.0.0-rc.12': + resolution: {integrity: sha512-vRugONE4yMfVn0+7lUKdKvN4D5YusEiPilaoO2sgUWpCvrncvWgPMzK00ZFFJuiPgLwgFNP5eSiUlv2tfc+lpA==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [arm64] + os: [openharmony] + + '@rolldown/binding-wasm32-wasi@1.0.0-rc.12': + resolution: {integrity: sha512-ykGiLr/6kkiHc0XnBfmFJuCjr5ZYKKofkx+chJWDjitX+KsJuAmrzWhwyOMSHzPhzOHOy7u9HlFoa5MoAOJ/Zg==} + engines: {node: '>=14.0.0'} + cpu: [wasm32] + + '@rolldown/binding-win32-arm64-msvc@1.0.0-rc.12': + resolution: {integrity: sha512-5eOND4duWkwx1AzCxadcOrNeighiLwMInEADT0YM7xeEOOFcovWZCq8dadXgcRHSf3Ulh1kFo/qvzoFiCLOL1Q==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [arm64] + os: [win32] + + '@rolldown/binding-win32-x64-msvc@1.0.0-rc.12': + resolution: {integrity: sha512-PyqoipaswDLAZtot351MLhrlrh6lcZPo2LSYE+VDxbVk24LVKAGOuE4hb8xZQmrPAuEtTZW8E6D2zc5EUZX4Lw==} + engines: {node: ^20.19.0 
|| >=22.12.0} + cpu: [x64] + os: [win32] + + '@rolldown/pluginutils@1.0.0-rc.12': + resolution: {integrity: sha512-HHMwmarRKvoFsJorqYlFeFRzXZqCt2ETQlEDOb9aqssrnVBB1/+xgTGtuTrIk5vzLNX1MjMtTf7W9z3tsSbrxw==} + + '@standard-schema/spec@1.1.0': + resolution: {integrity: sha512-l2aFy5jALhniG5HgqrD6jXLi/rUWrKvqN/qJx6yoJsgKhblVd+iqqU4RCXavm/jPityDo5TCvKMnpjKnOriy0w==} + + '@teppeis/multimaps@3.0.0': + resolution: {integrity: sha512-ID7fosbc50TbT0MK0EG12O+gAP3W3Aa/Pz4DaTtQtEvlc9Odaqi0de+xuZ7Li2GtK4HzEX7IuRWS/JmZLksR3Q==} + engines: {node: '>=14'} + + '@tybys/wasm-util@0.10.1': + resolution: {integrity: sha512-9tTaPJLSiejZKx+Bmog4uSubteqTvFrVrURwkmHixBo0G4seD0zUxp98E1DzUBJxLQ3NPwXrGKDiVjwx/DpPsg==} + + '@types/chai@5.2.3': + resolution: {integrity: sha512-Mw558oeA9fFbv65/y4mHtXDs9bPnFMZAL/jxdPFUpOHHIXX91mcgEHbS5Lahr+pwZFR8A7GQleRWeI6cGFC2UA==} + + '@types/deep-eql@4.0.2': + resolution: {integrity: sha512-c9h9dVVMigMPc4bwTvC5dxqtqJZwQPePsWjPlpSOnojbor6pGqdk541lfA7AqFQr5pB1BRdq0juY9db81BwyFw==} + + '@types/node@25.5.0': + resolution: {integrity: sha512-jp2P3tQMSxWugkCUKLRPVUpGaL5MVFwF8RDuSRztfwgN1wmqJeMSbKlnEtQqU8UrhTmzEmZdu2I6v2dpp7XIxw==} + + '@types/normalize-package-data@2.4.4': + resolution: {integrity: sha512-37i+OaWTh9qeK4LSHPsyRC7NahnGotNuZvjLSgcPzblpHB3rrCJxAOgI5gCdKm7coonsaX1Of0ILiTcnZjbfxA==} + + '@voidzero-dev/vite-plus-core@0.1.14': + resolution: {integrity: sha512-CCWzdkfW0fo0cQNlIsYp5fOuH2IwKuPZEb2UY2Z8gXcp5pG74A82H2Pthj0heAuvYTAnfT7kEC6zM+RbiBgQbg==} + engines: {node: ^20.19.0 || >=22.12.0} + peerDependencies: + '@arethetypeswrong/core': ^0.18.1 + '@tsdown/css': 0.21.4 + '@tsdown/exe': 0.21.4 + '@types/node': ^20.19.0 || >=22.12.0 + '@vitejs/devtools': ^0.1.0 + esbuild: ^0.27.0 + jiti: '>=1.21.0' + less: ^4.0.0 + publint: ^0.3.0 + sass: ^1.70.0 + sass-embedded: ^1.70.0 + stylus: '>=0.54.8' + sugarss: ^5.0.0 + terser: ^5.16.0 + tsx: ^4.8.1 + typescript: ^5.0.0 + unplugin-unused: ^0.5.0 + yaml: ^2.4.2 + peerDependenciesMeta: + '@arethetypeswrong/core': + 
optional: true + '@tsdown/css': + optional: true + '@tsdown/exe': + optional: true + '@types/node': + optional: true + '@vitejs/devtools': + optional: true + esbuild: + optional: true + jiti: + optional: true + less: + optional: true + publint: + optional: true + sass: + optional: true + sass-embedded: + optional: true + stylus: + optional: true + sugarss: + optional: true + terser: + optional: true + tsx: + optional: true + typescript: + optional: true + unplugin-unused: + optional: true + yaml: + optional: true + + '@voidzero-dev/vite-plus-darwin-arm64@0.1.14': + resolution: {integrity: sha512-q2ESUSbapwsxVRe/KevKATahNRraoX5nti3HT9S3266OHT5sMroBY14jaxTv74ekjQc9E6EPhyLGQWuWQuuBRw==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [arm64] + os: [darwin] + + '@voidzero-dev/vite-plus-darwin-x64@0.1.14': + resolution: {integrity: sha512-UpcDZc9G99E/4HDRoobvYHxMvFOG5uv3RwEcq0HF70u4DsnEMl1z8RaJLeWV7a09LGwj9Q+YWC3Z4INWnTLs8g==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [x64] + os: [darwin] + + '@voidzero-dev/vite-plus-linux-arm64-gnu@0.1.14': + resolution: {integrity: sha512-GIjn35RABUEDB9gHD26nRq7T72Te+Qy2+NIzogwEaUE728PvPkatF5gMCeF4sigCoc8c4qxDwsG+A2A2LYGnDg==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [arm64] + os: [linux] + libc: [glibc] + + '@voidzero-dev/vite-plus-linux-arm64-musl@0.1.14': + resolution: {integrity: sha512-qo2RToGirG0XCcxZ2AEOuonLM256z6dNbJzDDIo5gWYA+cIKigFQJbkPyr25zsT1tsP2aY0OTxt2038XbVlRkQ==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [arm64] + os: [linux] + libc: [musl] + + '@voidzero-dev/vite-plus-linux-x64-gnu@0.1.14': + resolution: {integrity: sha512-BsMWKZfdfGcYLxxLyaePpg6NW54xqzzcfq8sFUwKfwby0kgOKQ4WymUXyBvO9nnBb0ZPsJQrV0sx+Onac/LTaw==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [x64] + os: [linux] + libc: [glibc] + + '@voidzero-dev/vite-plus-linux-x64-musl@0.1.14': + resolution: {integrity: sha512-mOrEpj7ntW9RopGbcOYG/L0pOs0qHzUG4Vz7NXbuf4dbOSlY4JjyoMOIWxjKQORQht02Hzuf8YrMGNwa6AjVSQ==} + engines: {node: ^20.19.0 || 
>=22.12.0} + cpu: [x64] + os: [linux] + libc: [musl] + + '@voidzero-dev/vite-plus-test@0.1.14': + resolution: {integrity: sha512-rjF+qpYD+5+THOJZ3gbE3+cxsk5sW7nJ0ODK7y6ZKeS4amREUMedEDYykzKBwR7OZDC/WwE90A0iLWCr6qAXhA==} + engines: {node: ^20.0.0 || ^22.0.0 || >=24.0.0} + peerDependencies: + '@edge-runtime/vm': '*' + '@opentelemetry/api': ^1.9.0 + '@types/node': ^20.0.0 || ^22.0.0 || >=24.0.0 + '@vitest/ui': 4.1.1 + happy-dom: '*' + jsdom: '*' + vite: ^6.0.0 || ^7.0.0 || ^8.0.0 + peerDependenciesMeta: + '@edge-runtime/vm': + optional: true + '@opentelemetry/api': + optional: true + '@types/node': + optional: true + '@vitest/ui': + optional: true + happy-dom: + optional: true + jsdom: + optional: true + + '@voidzero-dev/vite-plus-win32-arm64-msvc@0.1.14': + resolution: {integrity: sha512-7iC+Ig+8D/zACy0IJf7w/vQ7duTjux9Ttmm3KOBdVWH4dl3JihydA7+SQVMhz71a4WiqJ6nPidoG8D6hUP4MVQ==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [arm64] + os: [win32] + + '@voidzero-dev/vite-plus-win32-x64-msvc@0.1.14': + resolution: {integrity: sha512-yRJ/8yAYFluNHx0Ej6Kevx65MIeM3wFKklnxosVZRlz2ZRL1Ea1Qh3tWATr3Ipk1ciRxBv8KJgp6zXqjxtZSoQ==} + engines: {node: ^20.19.0 || >=22.12.0} + cpu: [x64] + os: [win32] + + ansi-regex@4.1.1: + resolution: {integrity: sha512-ILlv4k/3f6vfQ4OoP2AGvirOktlQ98ZEL1k9FaQjxa3L1abBgbuTDAdPOpvbGncC0BTVQrl+OM8xZGK6tWXt7g==} + engines: {node: '>=6'} + + ansi-regex@5.0.1: + resolution: {integrity: sha512-quJQXlTSUGL2LH9SUXo8VwsY4soanhgo6LNSm84E1LBcE8s3O0wpdiRzyR9z/ZZJMlMWv37qOOb9pdJlMUEKFQ==} + engines: {node: '>=8'} + + ansi-styles@4.3.0: + resolution: {integrity: sha512-zbB9rCJAT1rbjiVDb2hqKFHNYLxgtk8NURxZ3IZwD3F6NtxbXZQCnnSi1Lkx+IDohdPlFp222wVALIheZJQSEg==} + engines: {node: '>=8'} + + ansi-styles@5.2.0: + resolution: {integrity: sha512-Cxwpt2SfTzTtXcfOlzGEee8O+c+MmUgGrNiBcXnuWxuFJHe6a5Hz7qwhwe5OgaSYI0IJvkLqWX1ASG+cJOkEiA==} + engines: {node: '>=10'} + + any-promise@1.3.0: + resolution: {integrity: 
sha512-7UvmKalWRt1wgjL1RrGxoSJW/0QZFIegpeGvZG9kjp8vrRu55XTHbwnqq2GpXm9uLbcuhxm3IqX9OB4MZR1b2A==} + + assertion-error-formatter@3.0.0: + resolution: {integrity: sha512-6YyAVLrEze0kQ7CmJfUgrLHb+Y7XghmL2Ie7ijVa2Y9ynP3LV+VDiwFk62Dn0qtqbmY0BT0ss6p1xxpiF2PYbQ==} + + assertion-error@2.0.1: + resolution: {integrity: sha512-Izi8RQcffqCeNVgFigKli1ssklIbpHnCYc6AknXGYoB6grJqyeby7jv12JUQgmTAnIDnbck1uxksT4dzN3PWBA==} + engines: {node: '>=12'} + + balanced-match@4.0.4: + resolution: {integrity: sha512-BLrgEcRTwX2o6gGxGOCNyMvGSp35YofuYzw9h1IMTRmKqttAZZVU67bdb9Pr2vUHA8+j3i2tJfjO6C6+4myGTA==} + engines: {node: 18 || 20 || >=22} + + brace-expansion@5.0.5: + resolution: {integrity: sha512-VZznLgtwhn+Mact9tfiwx64fA9erHH/MCXEUfB/0bX/6Fz6ny5EGTXYltMocqg4xFAQZtnO3DHWWXi8RiuN7cQ==} + engines: {node: 18 || 20 || >=22} + + buffer-from@1.1.2: + resolution: {integrity: sha512-E+XQCRwSbaaiChtv6k6Dwgc+bx+Bs6vuKJHHl5kox/BaKbhiXzqQOwK4cO22yElGp2OCmjwVhT3HmxgyPGnJfQ==} + + cac@7.0.0: + resolution: {integrity: sha512-tixWYgm5ZoOD+3g6UTea91eow5z6AAHaho3g0V9CNSNb45gM8SmflpAc+GRd1InC4AqN/07Unrgp56Y94N9hJQ==} + engines: {node: '>=20.19.0'} + + capital-case@1.0.4: + resolution: {integrity: sha512-ds37W8CytHgwnhGGTi88pcPyR15qoNkOpYwmMMfnWqqWgESapLqvDx6huFjQ5vqWSn2Z06173XNA7LtMOeUh1A==} + + chalk@4.1.2: + resolution: {integrity: sha512-oKnbhFyRIXpUuez8iBMmyEa4nbj4IOQyuhc/wy9kY7/WVPcwIO9VA668Pu8RkO7+0G76SLROeyw9CpQ061i4mA==} + engines: {node: '>=10'} + + class-transformer@0.5.1: + resolution: {integrity: sha512-SQa1Ws6hUbfC98vKGxZH3KFY0Y1lm5Zm0SY8XX9zbK7FJCyVEac3ATW0RIpwzW+oOfmHE5PMPufDG9hCfoEOMw==} + + cli-table3@0.6.5: + resolution: {integrity: sha512-+W/5efTR7y5HRD7gACw9yQjqMVvEMLBHmboM/kPWam+H+Hmyrgjh6YncVKK122YZkXrLudzTuAukUw9FnMf7IQ==} + engines: {node: 10.* || >= 12.*} + + color-convert@2.0.1: + resolution: {integrity: sha512-RRECPsj7iu/xb5oKYcsFHSppFNnsj/52OVTRKb4zP5onXwVF3zVmmToNcOfGC+CRDpfK/U584fMg38ZHCaElKQ==} + engines: {node: '>=7.0.0'} + + color-name@1.1.4: + resolution: {integrity: 
sha512-dOy+3AuW3a2wNbZHIuMZpTcgjGuLU/uBL/ubcZF9OXbDo8ff4O8yVp5Bf0efS8uEoYo5q4Fx7dY9OgQGXgAsQA==} + + commander@14.0.0: + resolution: {integrity: sha512-2uM9rYjPvyq39NwLRqaiLtWHyDC1FvryJDa2ATTVims5YAS4PupsEQsDvP14FqhFr0P49CYDugi59xaxJlTXRA==} + engines: {node: '>=20'} + + commander@14.0.2: + resolution: {integrity: sha512-TywoWNNRbhoD0BXs1P3ZEScW8W5iKrnbithIl0YH+uCmBd0QpPOA8yc82DS3BIE5Ma6FnBVUsJ7wVUDz4dvOWQ==} + engines: {node: '>=20'} + + commander@14.0.3: + resolution: {integrity: sha512-H+y0Jo/T1RZ9qPP4Eh1pkcQcLRglraJaSLoyOtHxu6AapkjWVCy2Sit1QQ4x3Dng8qDlSsZEet7g5Pq06MvTgw==} + engines: {node: '>=20'} + + cross-spawn@7.0.6: + resolution: {integrity: sha512-uV2QOWP2nWzsy2aMp8aRibhi9dlzF5Hgh5SHaB9OiTGEyDTiJJyx0uy51QXdyWbtAHNua4XJzUKca3OzKUd3vA==} + engines: {node: '>= 8'} + + debug@4.4.3: + resolution: {integrity: sha512-RGwwWnwQvkVfavKVt22FGLw+xYSdzARwm0ru6DhTVA3umU5hZc28V3kO4stgYryrTlLpuvgI9GiijltAjNbcqA==} + engines: {node: '>=6.0'} + peerDependencies: + supports-color: '*' + peerDependenciesMeta: + supports-color: + optional: true + + detect-libc@2.1.2: + resolution: {integrity: sha512-Btj2BOOO83o3WyH59e8MgXsxEQVcarkUOpEYrubB0urwnN10yQ364rsiByU11nZlqWYZm05i/of7io4mzihBtQ==} + engines: {node: '>=8'} + + diff@4.0.4: + resolution: {integrity: sha512-X07nttJQkwkfKfvTPG/KSnE2OMdcUCao6+eXF3wmnIQRn2aPAHH3VxDbDOdegkd6JbPsXqShpvEOHfAT+nCNwQ==} + engines: {node: '>=0.3.1'} + + emoji-regex@8.0.0: + resolution: {integrity: sha512-MSjYzcWNOA0ewAHpz0MxpYFvwg6yjy1NG3xteoqz644VCo/RPgnr1/GGt+ic3iJTzQ8Eu3TdM14SawnVUmGE6A==} + + error-stack-parser@2.1.4: + resolution: {integrity: sha512-Sk5V6wVazPhq5MhpO+AUxJn5x7XSXGl1R93Vn7i+zS15KDVxQijejNCrz8340/2bgLBjR9GtEG8ZVKONDjcqGQ==} + + es-module-lexer@1.7.0: + resolution: {integrity: sha512-jEQoCwk8hyb2AZziIOLhDqpm5+2ww5uIE6lkO/6jcOCusfk6LhMHpXXfBLXTZ7Ydyt0j4VoUQv6uGNYbdW+kBA==} + + esbuild@0.27.4: + resolution: {integrity: sha512-Rq4vbHnYkK5fws5NF7MYTU68FPRE1ajX7heQ/8QXXWqNgqqJ/GkmmyxIzUnf2Sr/bakf8l54716CcMGHYhMrrQ==} + engines: {node: 
'>=18'} + hasBin: true + + escape-string-regexp@1.0.5: + resolution: {integrity: sha512-vbRorB5FUQWvla16U8R/qgaFIya2qGzwDrNmCZuYKrbdSUMG6I1ZCGQRefkRVhuOkIGVne7BQ35DSfo1qvJqFg==} + engines: {node: '>=0.8.0'} + + fdir@6.5.0: + resolution: {integrity: sha512-tIbYtZbucOs0BRGqPJkshJUYdL+SDH7dVM8gjy+ERp3WAUjLEFJE+02kanyHtwjWOnwrKYBiwAmM0p4kLJAnXg==} + engines: {node: '>=12.0.0'} + peerDependencies: + picomatch: ^3 || ^4 + peerDependenciesMeta: + picomatch: + optional: true + + figures@3.2.0: + resolution: {integrity: sha512-yaduQFRKLXYOGgEn6AZau90j3ggSOyiqXU0F9JZfeXYhNa+Jk4X+s45A2zg5jns87GAFa34BBm2kXw4XpNcbdg==} + engines: {node: '>=8'} + + find-up-simple@1.0.1: + resolution: {integrity: sha512-afd4O7zpqHeRyg4PfDQsXmlDe2PfdHtJt6Akt8jOWaApLOZk5JXs6VMR29lz03pRe9mpykrRCYIYxaJYcfpncQ==} + engines: {node: '>=18'} + + fsevents@2.3.2: + resolution: {integrity: sha512-xiqMQR4xAeHTuB9uWm+fFRcIOgKBMiOBP+eXiyT7jsgVCq1bkVygt00oASowB7EdtpOHaaPgKt812P9ab+DDKA==} + engines: {node: ^8.16.0 || ^10.6.0 || >=11.0.0} + os: [darwin] + + fsevents@2.3.3: + resolution: {integrity: sha512-5xoDfX+fL7faATnagmWPpbFtwh/R77WmMMqqHGS65C3vvB0YHrgF+B1YmZ3441tMj5n63k0212XNoJwzlhffQw==} + engines: {node: ^8.16.0 || ^10.6.0 || >=11.0.0} + os: [darwin] + + get-tsconfig@4.13.7: + resolution: {integrity: sha512-7tN6rFgBlMgpBML5j8typ92BKFi2sFQvIdpAqLA2beia5avZDrMs0FLZiM5etShWq5irVyGcGMEA1jcDaK7A/Q==} + + glob@13.0.6: + resolution: {integrity: sha512-Wjlyrolmm8uDpm/ogGyXZXb1Z+Ca2B8NbJwqBVg0axK9GbBeoS7yGV6vjXnYdGm6X53iehEuxxbyiKp8QmN4Vw==} + engines: {node: 18 || 20 || >=22} + + global-dirs@3.0.1: + resolution: {integrity: sha512-NBcGGFbBA9s1VzD41QXDG+3++t9Mn5t1FpLdhESY6oKY4gYTFpX4wO3sqGUa0Srjtbfj3szX0RnemmrVRUdULA==} + engines: {node: '>=10'} + + has-ansi@4.0.1: + resolution: {integrity: sha512-Qr4RtTm30xvEdqUXbSBVWDu+PrTokJOwe/FU+VdfJPk+MXAPoeOzKpRyrDTnZIJwAkQ4oBLTU53nu0HrkF/Z2A==} + engines: {node: '>=8'} + + has-flag@4.0.0: + resolution: {integrity: 
sha512-EykJT/Q1KjTWctppgIAgfSO0tKVuZUjhgMr17kqTumMl6Afv3EISleU7qZUzoXDFTAHTDC4NOoG/ZxU3EvlMPQ==} + engines: {node: '>=8'} + + hosted-git-info@9.0.2: + resolution: {integrity: sha512-M422h7o/BR3rmCQ8UHi7cyyMqKltdP9Uo+J2fXK+RSAY+wTcKOIRyhTuKv4qn+DJf3g+PL890AzId5KZpX+CBg==} + engines: {node: ^20.17.0 || >=22.9.0} + + indent-string@4.0.0: + resolution: {integrity: sha512-EdDDZu4A2OyIK7Lr/2zG+w5jmbuk1DVBnEwREQvBzspBJkCEbRa8GxU1lghYcaGJCnRWibjDXlq779X1/y5xwg==} + engines: {node: '>=8'} + + index-to-position@1.2.0: + resolution: {integrity: sha512-Yg7+ztRkqslMAS2iFaU+Oa4KTSidr63OsFGlOrJoW981kIYO3CGCS3wA95P1mUi/IVSJkn0D479KTJpVpvFNuw==} + engines: {node: '>=18'} + + ini@2.0.0: + resolution: {integrity: sha512-7PnF4oN3CvZF23ADhA5wRaYEQpJ8qygSkbtTXWBeXWXmEVRXK+1ITciHWwHhsjv1TmW0MgacIv6hEi5pX5NQdA==} + engines: {node: '>=10'} + + is-fullwidth-code-point@3.0.0: + resolution: {integrity: sha512-zymm5+u+sCsSWyD9qNaejV3DFvhCKclKdizYaJUuHA83RLjb7nSuGnddCHGv0hk+KY7BMAlsWeK4Ueg6EV6XQg==} + engines: {node: '>=8'} + + is-installed-globally@0.4.0: + resolution: {integrity: sha512-iwGqO3J21aaSkC7jWnHP/difazwS7SFeIqxv6wEtLU8Y5KlzFTjyqcSIT0d8s4+dDhKytsk9PJZ2BkS5eZwQRQ==} + engines: {node: '>=10'} + + is-path-inside@3.0.3: + resolution: {integrity: sha512-Fd4gABb+ycGAmKou8eMftCupSir5lRxqf4aD/vd0cD2qc4HL07OjCeuHMr8Ro4CoMaeCKDB0/ECBOVWjTwUvPQ==} + engines: {node: '>=8'} + + is-stream@2.0.1: + resolution: {integrity: sha512-hFoiJiTl63nn+kstHGBtewWSKnQLpyb155KHheA1l39uvtO9nWIop1p3udqPcUd/xbF1VLMO4n7OI6p7RbngDg==} + engines: {node: '>=8'} + + isexe@2.0.0: + resolution: {integrity: sha512-RHxMLp9lnKHGHRng9QFhRCMbYAcVpn69smSGcq3f36xjgVVWThj4qqLbTLlq7Ssj8B+fIQ1EuCEGI2lKsyQeIw==} + + js-tokens@4.0.0: + resolution: {integrity: sha512-RdJUflcE3cUzKiMqQgsCu06FPu9UdIJO0beYbPhHN4k6apgJtifcoCtT9bcxOpYBtpD2kCM6Sbzg4CausW/PKQ==} + + knuth-shuffle-seeded@1.0.6: + resolution: {integrity: sha512-9pFH0SplrfyKyojCLxZfMcvkhf5hH0d+UwR9nTVJ/DDQJGuzcXjTwB7TP7sDfehSudlGGaOLblmEWqv04ERVWg==} + + 
lightningcss-android-arm64@1.32.0: + resolution: {integrity: sha512-YK7/ClTt4kAK0vo6w3X+Pnm0D2cf2vPHbhOXdoNti1Ga0al1P4TBZhwjATvjNwLEBCnKvjJc2jQgHXH0NEwlAg==} + engines: {node: '>= 12.0.0'} + cpu: [arm64] + os: [android] + + lightningcss-darwin-arm64@1.32.0: + resolution: {integrity: sha512-RzeG9Ju5bag2Bv1/lwlVJvBE3q6TtXskdZLLCyfg5pt+HLz9BqlICO7LZM7VHNTTn/5PRhHFBSjk5lc4cmscPQ==} + engines: {node: '>= 12.0.0'} + cpu: [arm64] + os: [darwin] + + lightningcss-darwin-x64@1.32.0: + resolution: {integrity: sha512-U+QsBp2m/s2wqpUYT/6wnlagdZbtZdndSmut/NJqlCcMLTWp5muCrID+K5UJ6jqD2BFshejCYXniPDbNh73V8w==} + engines: {node: '>= 12.0.0'} + cpu: [x64] + os: [darwin] + + lightningcss-freebsd-x64@1.32.0: + resolution: {integrity: sha512-JCTigedEksZk3tHTTthnMdVfGf61Fky8Ji2E4YjUTEQX14xiy/lTzXnu1vwiZe3bYe0q+SpsSH/CTeDXK6WHig==} + engines: {node: '>= 12.0.0'} + cpu: [x64] + os: [freebsd] + + lightningcss-linux-arm-gnueabihf@1.32.0: + resolution: {integrity: sha512-x6rnnpRa2GL0zQOkt6rts3YDPzduLpWvwAF6EMhXFVZXD4tPrBkEFqzGowzCsIWsPjqSK+tyNEODUBXeeVHSkw==} + engines: {node: '>= 12.0.0'} + cpu: [arm] + os: [linux] + + lightningcss-linux-arm64-gnu@1.32.0: + resolution: {integrity: sha512-0nnMyoyOLRJXfbMOilaSRcLH3Jw5z9HDNGfT/gwCPgaDjnx0i8w7vBzFLFR1f6CMLKF8gVbebmkUN3fa/kQJpQ==} + engines: {node: '>= 12.0.0'} + cpu: [arm64] + os: [linux] + libc: [glibc] + + lightningcss-linux-arm64-musl@1.32.0: + resolution: {integrity: sha512-UpQkoenr4UJEzgVIYpI80lDFvRmPVg6oqboNHfoH4CQIfNA+HOrZ7Mo7KZP02dC6LjghPQJeBsvXhJod/wnIBg==} + engines: {node: '>= 12.0.0'} + cpu: [arm64] + os: [linux] + libc: [musl] + + lightningcss-linux-x64-gnu@1.32.0: + resolution: {integrity: sha512-V7Qr52IhZmdKPVr+Vtw8o+WLsQJYCTd8loIfpDaMRWGUZfBOYEJeyJIkqGIDMZPwPx24pUMfwSxxI8phr/MbOA==} + engines: {node: '>= 12.0.0'} + cpu: [x64] + os: [linux] + libc: [glibc] + + lightningcss-linux-x64-musl@1.32.0: + resolution: {integrity: sha512-bYcLp+Vb0awsiXg/80uCRezCYHNg1/l3mt0gzHnWV9XP1W5sKa5/TCdGWaR/zBM2PeF/HbsQv/j2URNOiVuxWg==} + engines: 
{node: '>= 12.0.0'} + cpu: [x64] + os: [linux] + libc: [musl] + + lightningcss-win32-arm64-msvc@1.32.0: + resolution: {integrity: sha512-8SbC8BR40pS6baCM8sbtYDSwEVQd4JlFTOlaD3gWGHfThTcABnNDBda6eTZeqbofalIJhFx0qKzgHJmcPTnGdw==} + engines: {node: '>= 12.0.0'} + cpu: [arm64] + os: [win32] + + lightningcss-win32-x64-msvc@1.32.0: + resolution: {integrity: sha512-Amq9B/SoZYdDi1kFrojnoqPLxYhQ4Wo5XiL8EVJrVsB8ARoC1PWW6VGtT0WKCemjy8aC+louJnjS7U18x3b06Q==} + engines: {node: '>= 12.0.0'} + cpu: [x64] + os: [win32] + + lightningcss@1.32.0: + resolution: {integrity: sha512-NXYBzinNrblfraPGyrbPoD19C1h9lfI/1mzgWYvXUTe414Gz/X1FD2XBZSZM7rRTrMA8JL3OtAaGifrIKhQ5yQ==} + engines: {node: '>= 12.0.0'} + + lodash.merge@4.6.2: + resolution: {integrity: sha512-0KpjqXRVvrYyCsX1swR/XTK0va6VQkQM6MNo7PqW77ByjAhoARA8EfrP1N4+KlKj8YS0ZUCtRT/YUuhyYDujIQ==} + + lodash.mergewith@4.6.2: + resolution: {integrity: sha512-GK3g5RPZWTRSeLSpgP8Xhra+pnjBC56q9FZYe1d5RN3TJ35dbkGy3YqBSMbyCrlbi+CM9Z3Jk5yTL7RCsqboyQ==} + + lodash.sortby@4.7.0: + resolution: {integrity: sha512-HDWXG8isMntAyRF5vZ7xKuEvOhT4AhlRt/3czTSjvGUxjYCBVRQY48ViDHyfYz9VIoBkW4TMGQNapx+l3RUwdA==} + + lower-case@2.0.2: + resolution: {integrity: sha512-7fm3l3NAF9WfN6W3JOmf5drwpVqX78JtoGJ3A6W0a6ZnldM41w2fV5D490psKFTpMds8TJse/eHLFFsNHHjHgg==} + + lru-cache@11.2.7: + resolution: {integrity: sha512-aY/R+aEsRelme17KGQa/1ZSIpLpNYYrhcrepKTZgE+W3WM16YMCaPwOHLHsmopZHELU0Ojin1lPVxKR0MihncA==} + engines: {node: 20 || >=22} + + luxon@3.7.2: + resolution: {integrity: sha512-vtEhXh/gNjI9Yg1u4jX/0YVPMvxzHuGgCm6tC5kZyb08yjGWGnqAjGJvcXbqQR2P3MyMEFnRbpcdFS6PBcLqew==} + engines: {node: '>=12'} + + mime@3.0.0: + resolution: {integrity: sha512-jSCU7/VB1loIWBZe14aEYHU/+1UMEHoaO7qxCOVJOw9GgH72VAWppxNcjU+x9a2k3GSIBXNKxXQFqRvvZ7vr3A==} + engines: {node: '>=10.0.0'} + hasBin: true + + minimatch@10.2.4: + resolution: {integrity: sha512-oRjTw/97aTBN0RHbYCdtF1MQfvusSIBQM0IZEgzl6426+8jSC0nF1a/GmnVLpfB9yyr6g6FTqWqiZVbxrtaCIg==} + engines: {node: 18 || 20 || >=22} + + 
minipass@7.1.3: + resolution: {integrity: sha512-tEBHqDnIoM/1rXME1zgka9g6Q2lcoCkxHLuc7ODJ5BxbP5d4c2Z5cGgtXAku59200Cx7diuHTOYfSBD8n6mm8A==} + engines: {node: '>=16 || 14 >=14.17'} + + mkdirp@3.0.1: + resolution: {integrity: sha512-+NsyUUAZDmo6YVHzL/stxSu3t9YS1iljliy3BSDrXJ/dkn1KYdmtZODGGjLcc9XLgVVpH4KshHB8XmZgMhaBXg==} + engines: {node: '>=10'} + hasBin: true + + mrmime@2.0.1: + resolution: {integrity: sha512-Y3wQdFg2Va6etvQ5I82yUhGdsKrcYox6p7FfL1LbK2J4V01F9TGlepTIhnK24t7koZibmg82KGglhA1XK5IsLQ==} + engines: {node: '>=10'} + + ms@2.1.3: + resolution: {integrity: sha512-6FlzubTLZG3J2a/NVCAleEhjzq5oxgHyaCU9yYXvcLsvoVaHJq/s5xXI6/XXP6tz7R9xAOtHnSO/tXtF3WRTlA==} + + mz@2.7.0: + resolution: {integrity: sha512-z81GNO7nnYMEhrGh9LeymoE4+Yr0Wn5McHIZMK5cfQCl+NDX08sCZgUc9/6MHni9IWuFLm1Z3HTCXu2z9fN62Q==} + + nanoid@3.3.11: + resolution: {integrity: sha512-N8SpfPUnUp1bK+PMYW8qSWdl9U+wwNWI4QKxOYDy9JAro3WMX7p2OeVRF9v+347pnakNevPmiHhNmZ2HbFA76w==} + engines: {node: ^10 || ^12 || ^13.7 || ^14 || >=15.0.1} + hasBin: true + + no-case@3.0.4: + resolution: {integrity: sha512-fgAN3jGAh+RoxUGZHTSOLJIqUc2wmoBwGR4tbpNAKmmovFoWq0OdRkb0VkldReO2a2iBT/OEulG9XSUc10r3zg==} + + normalize-package-data@8.0.0: + resolution: {integrity: sha512-RWk+PI433eESQ7ounYxIp67CYuVsS1uYSonX3kA6ps/3LWfjVQa/ptEg6Y3T6uAMq1mWpX9PQ+qx+QaHpsc7gQ==} + engines: {node: ^20.17.0 || >=22.9.0} + + object-assign@4.1.1: + resolution: {integrity: sha512-rJgTQnkUnH1sFw8yT6VSU3zD3sWmu6sZhIseY8VX+GRu3P6F7Fu+JNDoXfklElbLJSnc3FUQHVe4cU5hj+BcUg==} + engines: {node: '>=0.10.0'} + + obug@2.1.1: + resolution: {integrity: sha512-uTqF9MuPraAQ+IsnPf366RG4cP9RtUi7MLO1N3KEc+wb0a6yKpeL0lmk2IB1jY5KHPAlTc6T/JRdC/YqxHNwkQ==} + + oxfmt@0.42.0: + resolution: {integrity: sha512-QhejGErLSMReNuZ6vxgFHDyGoPbjTRNi6uGHjy0cvIjOQFqD6xmr/T+3L41ixR3NIgzcNiJ6ylQKpvShTgDfqg==} + engines: {node: ^20.19.0 || >=22.12.0} + hasBin: true + + oxlint-tsgolint@0.17.3: + resolution: {integrity: 
sha512-1eh4bcpOMw0e7+YYVxmhFc2mo/V6hJ2+zfukqf+GprvVn3y94b69M/xNrYLmx5A+VdYe0i/bJ2xOs6Hp/jRmRA==} + hasBin: true + + oxlint@1.57.0: + resolution: {integrity: sha512-DGFsuBX5MFZX9yiDdtKjTrYPq45CZ8Fft6qCltJITYZxfwYjVdGf/6wycGYTACloauwIPxUnYhBVeZbHvleGhw==} + engines: {node: ^20.19.0 || >=22.12.0} + hasBin: true + peerDependencies: + oxlint-tsgolint: '>=0.15.0' + peerDependenciesMeta: + oxlint-tsgolint: + optional: true + + pad-right@0.2.2: + resolution: {integrity: sha512-4cy8M95ioIGolCoMmm2cMntGR1lPLEbOMzOKu8bzjuJP6JpzEMQcDHmh7hHLYGgob+nKe1YHFMaG4V59HQa89g==} + engines: {node: '>=0.10.0'} + + parse-json@8.3.0: + resolution: {integrity: sha512-ybiGyvspI+fAoRQbIPRddCcSTV9/LsJbf0e/S85VLowVGzRmokfneg2kwVW/KU5rOXrPSbF1qAKPMgNTqqROQQ==} + engines: {node: '>=18'} + + path-key@3.1.1: + resolution: {integrity: sha512-ojmeN0qd+y0jszEtoY48r0Peq5dwMEkIlCOu6Q5f41lfkswXuKtYrhgoTpLnyIcHm24Uhqx+5Tqm2InSwLhE6Q==} + engines: {node: '>=8'} + + path-scurry@2.0.2: + resolution: {integrity: sha512-3O/iVVsJAPsOnpwWIeD+d6z/7PmqApyQePUtCndjatj/9I5LylHvt5qluFaBT3I5h3r1ejfR056c+FCv+NnNXg==} + engines: {node: 18 || 20 || >=22} + + picocolors@1.1.1: + resolution: {integrity: sha512-xceH2snhtb5M9liqDsmEw56le376mTZkEX/jEb/RxNFyegNul7eNslCXP9FDj/Lcu0X8KEyMceP2ntpaHrDEVA==} + + picomatch@4.0.4: + resolution: {integrity: sha512-QP88BAKvMam/3NxH6vj2o21R6MjxZUAd6nlwAS/pnGvN9IVLocLHxGYIzFhg6fUQ+5th6P4dv4eW9jX3DSIj7A==} + engines: {node: '>=12'} + + pixelmatch@7.1.0: + resolution: {integrity: sha512-1wrVzJ2STrpmONHKBy228LM1b84msXDUoAzVEl0R8Mz4Ce6EPr+IVtxm8+yvrqLYMHswREkjYFaMxnyGnaY3Ng==} + hasBin: true + + playwright-core@1.51.1: + resolution: {integrity: sha512-/crRMj8+j/Nq5s8QcvegseuyeZPxpQCZb6HNk3Sos3BlZyAknRjoyJPFWkpNn8v0+P3WiwqFF8P+zQo4eqiNuw==} + engines: {node: '>=18'} + hasBin: true + + playwright@1.51.1: + resolution: {integrity: sha512-kkx+MB2KQRkyxjYPc3a0wLZZoDczmppyGJIvQ43l+aZihkaVvmu/21kiyaHeHjiFxjxNNFnUncKmcGIyOojsaw==} + engines: {node: '>=18'} + hasBin: true + + pngjs@7.0.0: + resolution: 
{integrity: sha512-LKWqWJRhstyYo9pGvgor/ivk2w94eSjE3RGVuzLGlr3NmD8bf7RcYGze1mNdEHRP6TRP6rMuDHk5t44hnTRyow==} + engines: {node: '>=14.19.0'} + + postcss@8.5.8: + resolution: {integrity: sha512-OW/rX8O/jXnm82Ey1k44pObPtdblfiuWnrd8X7GJ7emImCOstunGbXUpp7HdBrFQX6rJzn3sPT397Wp5aCwCHg==} + engines: {node: ^10 || ^12 || >=14} + + progress@2.0.3: + resolution: {integrity: sha512-7PiHtLll5LdnKIMw100I+8xJXR5gW2QwWYkT6iJva0bXitZKa/XMrSbdmg3r2Xnaidz9Qumd0VPaMrZlF9V9sA==} + engines: {node: '>=0.4.0'} + + property-expr@2.0.6: + resolution: {integrity: sha512-SVtmxhRE/CGkn3eZY1T6pC8Nln6Fr/lu1mKSgRud0eC73whjGfoAogbn78LkD8aFL0zz3bAFerKSnOl7NlErBA==} + + read-package-up@12.0.0: + resolution: {integrity: sha512-Q5hMVBYur/eQNWDdbF4/Wqqr9Bjvtrw2kjGxxBbKLbx8bVCL8gcArjTy8zDUuLGQicftpMuU0riQNcAsbtOVsw==} + engines: {node: '>=20'} + + read-pkg@10.1.0: + resolution: {integrity: sha512-I8g2lArQiP78ll51UeMZojewtYgIRCKCWqZEgOO8c/uefTI+XDXvCSXu3+YNUaTNvZzobrL5+SqHjBrByRRTdg==} + engines: {node: '>=20'} + + reflect-metadata@0.2.2: + resolution: {integrity: sha512-urBwgfrvVP/eAyXx4hluJivBKzuEbSQs9rKWCrCkbSxNv8mxPcUZKeuoF3Uy4mJl3Lwprp6yy5/39VWigZ4K6Q==} + + regexp-match-indices@1.0.2: + resolution: {integrity: sha512-DwZuAkt8NF5mKwGGER1EGh2PRqyvhRhhLviH+R8y8dIuaQROlUfXjt4s9ZTXstIsSkptf06BSvwcEmmfheJJWQ==} + + regexp-tree@0.1.27: + resolution: {integrity: sha512-iETxpjK6YoRWJG5o6hXLwvjYAoW+FEZn9os0PD/b6AP6xQwsa/Y7lCVgIixBbUPMfhu+i2LtdeAqVTgGlQarfA==} + hasBin: true + + repeat-string@1.6.1: + resolution: {integrity: sha512-PV0dzCYDNfRi1jCDbJzpW7jNNDRuCOG/jI5ctQcGKt/clZD+YcPS3yIlWuTJMmESC8aevCFmWJy5wjAFgNqN6w==} + engines: {node: '>=0.10'} + + resolve-pkg-maps@1.0.0: + resolution: {integrity: sha512-seS2Tj26TBVOC2NIc2rOe2y2ZO7efxITtLZcGSOnHHNOQ7CkiUBfw0Iw2ck6xkIhPwLhKNLS8BO+hEpngQlqzw==} + + rolldown@1.0.0-rc.12: + resolution: {integrity: sha512-yP4USLIMYrwpPHEFB5JGH1uxhcslv6/hL0OyvTuY+3qlOSJvZ7ntYnoWpehBxufkgN0cvXxppuTu5hHa/zPh+A==} + engines: {node: ^20.19.0 || >=22.12.0} + hasBin: true + + 
seed-random@2.2.0: + resolution: {integrity: sha512-34EQV6AAHQGhoc0tn/96a9Fsi6v2xdqe/dMUwljGRaFOzR3EgRmECvD0O8vi8X+/uQ50LGHfkNu/Eue5TPKZkQ==} + + semver@7.7.4: + resolution: {integrity: sha512-vFKC2IEtQnVhpT78h1Yp8wzwrf8CM+MzKMHGJZfBtzhZNycRFnXsHk6E5TxIkkMsgNS7mdX3AGB7x2QM2di4lA==} + engines: {node: '>=10'} + hasBin: true + + shebang-command@2.0.0: + resolution: {integrity: sha512-kHxr2zZpYtdmrN1qDjrrX/Z1rR1kG8Dx+gkpK1G4eXmvXswmcE1hTWBWYUzlraYw1/yZp6YuDY77YtvbN0dmDA==} + engines: {node: '>=8'} + + shebang-regex@3.0.0: + resolution: {integrity: sha512-7++dFhtcx3353uBaq8DDR4NuxBetBzC7ZQOhmTQInHEd6bSrXdiEyzCvG07Z44UYdLShWUyXt5M/yhz8ekcb1A==} + engines: {node: '>=8'} + + sirv@3.0.2: + resolution: {integrity: sha512-2wcC/oGxHis/BoHkkPwldgiPSYcpZK3JU28WoMVv55yHJgcZ8rlXvuG9iZggz+sU1d4bRgIGASwyWqjxu3FM0g==} + engines: {node: '>=18'} + + source-map-js@1.2.1: + resolution: {integrity: sha512-UXWMKhLOwVKb728IUtQPXxfYU+usdybtUrK/8uGE8CQMvrhOpwvzDBwj0QhSL7MQc7vIsISBG8VQ8+IDQxpfQA==} + engines: {node: '>=0.10.0'} + + source-map-support@0.5.21: + resolution: {integrity: sha512-uBHU3L3czsIyYXKX88fdrGovxdSCoTGDRZ6SYXtSRxLZUzHg5P/66Ht6uoUlHu9EZod+inXhKo3qQgwXUT/y1w==} + + source-map@0.6.1: + resolution: {integrity: sha512-UjgapumWlbMhkBgzT7Ykc5YXUT46F0iKu8SGXq0bcwP5dz/h0Plj6enJqjz1Zbq2l5WaqYnrVbwWOWMyF3F47g==} + engines: {node: '>=0.10.0'} + + spdx-correct@3.2.0: + resolution: {integrity: sha512-kN9dJbvnySHULIluDHy32WHRUu3Og7B9sbY7tsFLctQkIqnMh3hErYgdMjTYuqmcXX+lK5T1lnUt3G7zNswmZA==} + + spdx-exceptions@2.5.0: + resolution: {integrity: sha512-PiU42r+xO4UbUS1buo3LPJkjlO7430Xn5SVAhdpzzsPHsjbYVflnnFdATgabnLude+Cqu25p6N+g2lw/PFsa4w==} + + spdx-expression-parse@3.0.1: + resolution: {integrity: sha512-cbqHunsQWnJNE6KhVSMsMeH5H/L9EpymbzqTQ3uLwNCLZ1Q481oWaofqH7nO6V07xlXwY6PhQdQ2IedWx/ZK4Q==} + + spdx-license-ids@3.0.23: + resolution: {integrity: sha512-CWLcCCH7VLu13TgOH+r8p1O/Znwhqv/dbb6lqWy67G+pT1kHmeD/+V36AVb/vq8QMIQwVShJ6Ssl5FPh0fuSdw==} + + stackframe@1.3.4: + resolution: {integrity: 
sha512-oeVtt7eWQS+Na6F//S4kJ2K2VbRlS9D43mAlMyVpVWovy9o+jfgH8O9agzANzaiLjclA0oYzUXEM4PurhSUChw==} + + std-env@4.0.0: + resolution: {integrity: sha512-zUMPtQ/HBY3/50VbpkupYHbRroTRZJPRLvreamgErJVys0ceuzMkD44J/QjqhHjOzK42GQ3QZIeFG1OYfOtKqQ==} + + string-argv@0.3.1: + resolution: {integrity: sha512-a1uQGz7IyVy9YwhqjZIZu1c8JO8dNIe20xBmSS6qu9kv++k3JGzCVmprbNN5Kn+BgzD5E7YYwg1CcjuJMRNsvg==} + engines: {node: '>=0.6.19'} + + string-width@4.2.3: + resolution: {integrity: sha512-wKyQRQpjJ0sIp62ErSZdGsjMJWsap5oRNihHhu6G7JVO/9jIB6UyevL+tXuOqrng8j/cxKTWyWUwvSTriiZz/g==} + engines: {node: '>=8'} + + strip-ansi@6.0.1: + resolution: {integrity: sha512-Y38VPSHcqkFrCpFnQ9vuSXmquuv5oXOKpGeT6aGrr3o3Gc9AlVa6JBfUSOCnbxGGZF+/0ooI7KrPuUSztUdU5A==} + engines: {node: '>=8'} + + supports-color@7.2.0: + resolution: {integrity: sha512-qpCAvRl9stuOHveKsn7HncJRvv501qIacKzQlO/+Lwxc9+0q2wLyv4Dfvt80/DPn2pqOBsJdDiogXGR9+OvwRw==} + engines: {node: '>=8'} + + supports-color@8.1.1: + resolution: {integrity: sha512-MpUEN2OodtUzxvKQl72cUF7RQ5EiHsGvSsVG0ia9c5RbWGL2CI4C7EpPS8UTBIplnlzZiNuV56w+FuNxy3ty2Q==} + engines: {node: '>=10'} + + tagged-tag@1.0.0: + resolution: {integrity: sha512-yEFYrVhod+hdNyx7g5Bnkkb0G6si8HJurOoOEgC8B/O0uXLHlaey/65KRv6cuWBNhBgHKAROVpc7QyYqE5gFng==} + engines: {node: '>=20'} + + thenify-all@1.6.0: + resolution: {integrity: sha512-RNxQH/qI8/t3thXJDwcstUO4zeqo64+Uy/+sNVRBx4Xn2OX+OZ9oP+iJnNFqplFra2ZUVeKCSa2oVWi3T4uVmA==} + engines: {node: '>=0.8'} + + thenify@3.3.1: + resolution: {integrity: sha512-RVZSIV5IG10Hk3enotrhvz0T9em6cyHBLkH/YAZuKqd8hRkKhSfCGIcP2KUY0EPxndzANBmNllzWPwak+bheSw==} + + tiny-case@1.0.3: + resolution: {integrity: sha512-Eet/eeMhkO6TX8mnUteS9zgPbUMQa4I6Kkp5ORiBD5476/m+PIRiumP5tmh5ioJpH7k51Kehawy2UDfsnxxY8Q==} + + tinybench@2.9.0: + resolution: {integrity: sha512-0+DUvqWMValLmha6lr4kD8iAMK1HzV0/aKnCtWb9v9641TnP/MFb7Pc2bxoxQjTXAErryXVgUOfv2YqNllqGeg==} + + tinyexec@1.0.4: + resolution: {integrity: 
sha512-u9r3uZC0bdpGOXtlxUIdwf9pkmvhqJdrVCH9fapQtgy/OeTTMZ1nqH7agtvEfmGui6e1XxjcdrlxvxJvc3sMqw==} + engines: {node: '>=18'} + + tinyglobby@0.2.15: + resolution: {integrity: sha512-j2Zq4NyQYG5XMST4cbs02Ak8iJUdxRM0XI5QyxXuZOzKOINmWurp3smXu3y5wDcJrptwpSjgXHzIQxR0omXljQ==} + engines: {node: '>=12.0.0'} + + tinypool@2.1.0: + resolution: {integrity: sha512-Pugqs6M0m7Lv1I7FtxN4aoyToKg1C4tu+/381vH35y8oENM/Ai7f7C4StcoK4/+BSw9ebcS8jRiVrORFKCALLw==} + engines: {node: ^20.0.0 || >=22.0.0} + + toposort@2.0.2: + resolution: {integrity: sha512-0a5EOkAUp8D4moMi2W8ZF8jcga7BgZd91O/yabJCFY8az+XSzeGyTKs0Aoo897iV1Nj6guFq8orWDS96z91oGg==} + + totalist@3.0.1: + resolution: {integrity: sha512-sf4i37nQ2LBx4m3wB74y+ubopq6W/dIzXg0FDGjsYnZHVa1Da8FH853wlL2gtUhg+xJXjfk3kUZS3BRoQeoQBQ==} + engines: {node: '>=6'} + + ts-dedent@2.2.0: + resolution: {integrity: sha512-q5W7tVM71e2xjHZTlgfTDoPF/SmqKG5hddq9SzR49CH2hayqRKJtQ4mtRlSxKaJlR/+9rEM+mnBHf7I2/BQcpQ==} + engines: {node: '>=6.10'} + + tslib@2.8.1: + resolution: {integrity: sha512-oJFu94HQb+KVduSUQL7wnpmqnfmLsOA/nAh6b6EH0wCEoK0/mPeXU6c3wKDV83MkOuHPRHtSXKKU99IBazS/2w==} + + tsx@4.21.0: + resolution: {integrity: sha512-5C1sg4USs1lfG0GFb2RLXsdpXqBSEhAaA/0kPL01wxzpMqLILNxIxIOKiILz+cdg/pLnOUxFYOR5yhHU666wbw==} + engines: {node: '>=18.0.0'} + hasBin: true + + type-fest@2.19.0: + resolution: {integrity: sha512-RAH822pAdBgcNMAfWnCBU3CFZcfZ/i1eZjwFU/dsLKumyuuP3niueg2UAukXYF0E2AAoc82ZSSf9J0WQBinzHA==} + engines: {node: '>=12.20'} + + type-fest@4.41.0: + resolution: {integrity: sha512-TeTSQ6H5YHvpqVwBRcnLDCBnDOHWYu7IvGbHT6N8AOymcr9PJGjc1GTtiWZTYg0NCgYwvnYWEkVChQAr9bjfwA==} + engines: {node: '>=16'} + + type-fest@5.5.0: + resolution: {integrity: sha512-PlBfpQwiUvGViBNX84Yxwjsdhd1TUlXr6zjX7eoirtCPIr08NAmxwa+fcYBTeRQxHo9YC9wwF3m9i700sHma8g==} + engines: {node: '>=20'} + + typescript@5.9.3: + resolution: {integrity: sha512-jl1vZzPDinLr9eUt3J/t7V6FgNEw9QjvBPdysz9KfQDD41fQrC2Y4vKQdiaUpFT4bXlb1RHhLpp8wtm6M5TgSw==} + engines: {node: '>=14.17'} + hasBin: true + + 
undici-types@7.18.2: + resolution: {integrity: sha512-AsuCzffGHJybSaRrmr5eHr81mwJU3kjw6M+uprWvCXiNeN9SOGwQ3Jn8jb8m3Z6izVgknn1R0FTCEAP2QrLY/w==} + + unicorn-magic@0.4.0: + resolution: {integrity: sha512-wH590V9VNgYH9g3lH9wWjTrUoKsjLF6sGLjhR4sH1LWpLmCOH0Zf7PukhDA8BiS7KHe4oPNkcTHqYkj7SOGUOw==} + engines: {node: '>=20'} + + upper-case-first@2.0.2: + resolution: {integrity: sha512-514ppYHBaKwfJRK/pNC6c/OxfGa0obSnAl106u97Ed0I625Nin96KAjttZF6ZL3e1XLtphxnqrOi9iWgm+u+bg==} + + util-arity@1.1.0: + resolution: {integrity: sha512-kkyIsXKwemfSy8ZEoaIz06ApApnWsk5hQO0vLjZS6UkBiGiW++Jsyb8vSBoc0WKlffGoGs5yYy/j5pp8zckrFA==} + + validate-npm-package-license@3.0.4: + resolution: {integrity: sha512-DpKm2Ui/xN7/HQKCtpZxoRWBhZ9Z0kqtygG8XCgNQ8ZlDnxuQmWhj566j8fN4Cu3/JmbhsDo7fcAJq4s9h27Ew==} + + vite-plus@0.1.14: + resolution: {integrity: sha512-p4pWlpZZNiEsHxPWNdeIU9iuPix3ydm3ficb0dXPggoyIkdotfXtvn2NPX9KwfiQImU72EVEs4+VYBZYNcUYrw==} + engines: {node: ^20.19.0 || >=22.12.0} + hasBin: true + + vite@8.0.3: + resolution: {integrity: sha512-B9ifbFudT1TFhfltfaIPgjo9Z3mDynBTJSUYxTjOQruf/zHH+ezCQKcoqO+h7a9Pw9Nm/OtlXAiGT1axBgwqrQ==} + engines: {node: ^20.19.0 || >=22.12.0} + hasBin: true + peerDependencies: + '@types/node': ^20.19.0 || >=22.12.0 + '@vitejs/devtools': ^0.1.0 + esbuild: ^0.27.0 + jiti: '>=1.21.0' + less: ^4.0.0 + sass: ^1.70.0 + sass-embedded: ^1.70.0 + stylus: '>=0.54.8' + sugarss: ^5.0.0 + terser: ^5.16.0 + tsx: ^4.8.1 + yaml: ^2.4.2 + peerDependenciesMeta: + '@types/node': + optional: true + '@vitejs/devtools': + optional: true + esbuild: + optional: true + jiti: + optional: true + less: + optional: true + sass: + optional: true + sass-embedded: + optional: true + stylus: + optional: true + sugarss: + optional: true + terser: + optional: true + tsx: + optional: true + yaml: + optional: true + + which@2.0.2: + resolution: {integrity: sha512-BLI3Tl1TW3Pvl70l3yq3Y64i+awpwXqsGBYWkkqMtnbXgrMD+yj7rhW0kuEDxzJaYXGjEW5ogapKNMEKNMjibA==} + engines: {node: '>= 8'} + hasBin: true + + 
ws@8.20.0: + resolution: {integrity: sha512-sAt8BhgNbzCtgGbt2OxmpuryO63ZoDk/sqaB/znQm94T4fCEsy/yV+7CdC1kJhOU9lboAEU7R3kquuycDoibVA==} + engines: {node: '>=10.0.0'} + peerDependencies: + bufferutil: ^4.0.1 + utf-8-validate: '>=5.0.2' + peerDependenciesMeta: + bufferutil: + optional: true + utf-8-validate: + optional: true + + xmlbuilder@15.1.1: + resolution: {integrity: sha512-yMqGBqtXyeN1e3TGYvgNgDVZ3j84W4cwkOXQswghol6APgZWaff9lnbvN7MHYJOiXsvGPXtjTYJEiC9J2wv9Eg==} + engines: {node: '>=8.0'} + + yaml@2.8.3: + resolution: {integrity: sha512-AvbaCLOO2Otw/lW5bmh9d/WEdcDFdQp2Z2ZUH3pX9U2ihyUY0nvLv7J6TrWowklRGPYbB/IuIMfYgxaCPg5Bpg==} + engines: {node: '>= 14.6'} + hasBin: true + + yup@1.7.1: + resolution: {integrity: sha512-GKHFX2nXul2/4Dtfxhozv701jLQHdf6J34YDh2cEkpqoo8le5Mg6/LrdseVLrFarmFygZTlfIhHx/QKfb/QWXw==} + +snapshots: + + '@babel/code-frame@7.29.0': + dependencies: + '@babel/helper-validator-identifier': 7.28.5 + js-tokens: 4.0.0 + picocolors: 1.1.1 + + '@babel/helper-validator-identifier@7.28.5': {} + + '@colors/colors@1.5.0': + optional: true + + '@cucumber/ci-environment@13.0.0': {} + + '@cucumber/cucumber-expressions@19.0.0': + dependencies: + regexp-match-indices: 1.0.2 + + '@cucumber/cucumber@12.7.0': + dependencies: + '@cucumber/ci-environment': 13.0.0 + '@cucumber/cucumber-expressions': 19.0.0 + '@cucumber/gherkin': 38.0.0 + '@cucumber/gherkin-streams': 6.0.0(@cucumber/gherkin@38.0.0)(@cucumber/message-streams@4.0.1(@cucumber/messages@32.0.1))(@cucumber/messages@32.0.1) + '@cucumber/gherkin-utils': 11.0.0 + '@cucumber/html-formatter': 23.0.0(@cucumber/messages@32.0.1) + '@cucumber/junit-xml-formatter': 0.9.0(@cucumber/messages@32.0.1) + '@cucumber/message-streams': 4.0.1(@cucumber/messages@32.0.1) + '@cucumber/messages': 32.0.1 + '@cucumber/pretty-formatter': 1.0.1(@cucumber/cucumber@12.7.0)(@cucumber/messages@32.0.1) + '@cucumber/tag-expressions': 9.1.0 + assertion-error-formatter: 3.0.0 + capital-case: 1.0.4 + chalk: 4.1.2 + cli-table3: 0.6.5 + 
commander: 14.0.3 + debug: 4.4.3(supports-color@8.1.1) + error-stack-parser: 2.1.4 + figures: 3.2.0 + glob: 13.0.6 + has-ansi: 4.0.1 + indent-string: 4.0.0 + is-installed-globally: 0.4.0 + is-stream: 2.0.1 + knuth-shuffle-seeded: 1.0.6 + lodash.merge: 4.6.2 + lodash.mergewith: 4.6.2 + luxon: 3.7.2 + mime: 3.0.0 + mkdirp: 3.0.1 + mz: 2.7.0 + progress: 2.0.3 + read-package-up: 12.0.0 + semver: 7.7.4 + string-argv: 0.3.1 + supports-color: 8.1.1 + type-fest: 4.41.0 + util-arity: 1.1.0 + yaml: 2.8.3 + yup: 1.7.1 + + '@cucumber/gherkin-streams@6.0.0(@cucumber/gherkin@38.0.0)(@cucumber/message-streams@4.0.1(@cucumber/messages@32.0.1))(@cucumber/messages@32.0.1)': + dependencies: + '@cucumber/gherkin': 38.0.0 + '@cucumber/message-streams': 4.0.1(@cucumber/messages@32.0.1) + '@cucumber/messages': 32.0.1 + commander: 14.0.0 + source-map-support: 0.5.21 + + '@cucumber/gherkin-utils@11.0.0': + dependencies: + '@cucumber/gherkin': 38.0.0 + '@cucumber/messages': 32.0.1 + '@teppeis/multimaps': 3.0.0 + commander: 14.0.2 + source-map-support: 0.5.21 + + '@cucumber/gherkin@38.0.0': + dependencies: + '@cucumber/messages': 32.0.1 + + '@cucumber/html-formatter@23.0.0(@cucumber/messages@32.0.1)': + dependencies: + '@cucumber/messages': 32.0.1 + + '@cucumber/junit-xml-formatter@0.9.0(@cucumber/messages@32.0.1)': + dependencies: + '@cucumber/messages': 32.0.1 + '@cucumber/query': 14.7.0(@cucumber/messages@32.0.1) + '@teppeis/multimaps': 3.0.0 + luxon: 3.7.2 + xmlbuilder: 15.1.1 + + '@cucumber/message-streams@4.0.1(@cucumber/messages@32.0.1)': + dependencies: + '@cucumber/messages': 32.0.1 + + '@cucumber/messages@32.0.1': + dependencies: + class-transformer: 0.5.1 + reflect-metadata: 0.2.2 + + '@cucumber/pretty-formatter@1.0.1(@cucumber/cucumber@12.7.0)(@cucumber/messages@32.0.1)': + dependencies: + '@cucumber/cucumber': 12.7.0 + '@cucumber/messages': 32.0.1 + ansi-styles: 5.2.0 + cli-table3: 0.6.5 + figures: 3.2.0 + ts-dedent: 2.2.0 + + '@cucumber/query@14.7.0(@cucumber/messages@32.0.1)': 
+ dependencies: + '@cucumber/messages': 32.0.1 + '@teppeis/multimaps': 3.0.0 + lodash.sortby: 4.7.0 + + '@cucumber/tag-expressions@9.1.0': {} + + '@emnapi/core@1.9.1': + dependencies: + '@emnapi/wasi-threads': 1.2.0 + tslib: 2.8.1 + optional: true + + '@emnapi/runtime@1.9.1': + dependencies: + tslib: 2.8.1 + optional: true + + '@emnapi/wasi-threads@1.2.0': + dependencies: + tslib: 2.8.1 + optional: true + + '@esbuild/aix-ppc64@0.27.4': + optional: true + + '@esbuild/android-arm64@0.27.4': + optional: true + + '@esbuild/android-arm@0.27.4': + optional: true + + '@esbuild/android-x64@0.27.4': + optional: true + + '@esbuild/darwin-arm64@0.27.4': + optional: true + + '@esbuild/darwin-x64@0.27.4': + optional: true + + '@esbuild/freebsd-arm64@0.27.4': + optional: true + + '@esbuild/freebsd-x64@0.27.4': + optional: true + + '@esbuild/linux-arm64@0.27.4': + optional: true + + '@esbuild/linux-arm@0.27.4': + optional: true + + '@esbuild/linux-ia32@0.27.4': + optional: true + + '@esbuild/linux-loong64@0.27.4': + optional: true + + '@esbuild/linux-mips64el@0.27.4': + optional: true + + '@esbuild/linux-ppc64@0.27.4': + optional: true + + '@esbuild/linux-riscv64@0.27.4': + optional: true + + '@esbuild/linux-s390x@0.27.4': + optional: true + + '@esbuild/linux-x64@0.27.4': + optional: true + + '@esbuild/netbsd-arm64@0.27.4': + optional: true + + '@esbuild/netbsd-x64@0.27.4': + optional: true + + '@esbuild/openbsd-arm64@0.27.4': + optional: true + + '@esbuild/openbsd-x64@0.27.4': + optional: true + + '@esbuild/openharmony-arm64@0.27.4': + optional: true + + '@esbuild/sunos-x64@0.27.4': + optional: true + + '@esbuild/win32-arm64@0.27.4': + optional: true + + '@esbuild/win32-ia32@0.27.4': + optional: true + + '@esbuild/win32-x64@0.27.4': + optional: true + + '@napi-rs/wasm-runtime@1.1.1': + dependencies: + '@emnapi/core': 1.9.1 + '@emnapi/runtime': 1.9.1 + '@tybys/wasm-util': 0.10.1 + optional: true + + '@oxc-project/runtime@0.121.0': {} + + '@oxc-project/types@0.122.0': {} + + 
'@oxfmt/binding-android-arm-eabi@0.42.0': + optional: true + + '@oxfmt/binding-android-arm64@0.42.0': + optional: true + + '@oxfmt/binding-darwin-arm64@0.42.0': + optional: true + + '@oxfmt/binding-darwin-x64@0.42.0': + optional: true + + '@oxfmt/binding-freebsd-x64@0.42.0': + optional: true + + '@oxfmt/binding-linux-arm-gnueabihf@0.42.0': + optional: true + + '@oxfmt/binding-linux-arm-musleabihf@0.42.0': + optional: true + + '@oxfmt/binding-linux-arm64-gnu@0.42.0': + optional: true + + '@oxfmt/binding-linux-arm64-musl@0.42.0': + optional: true + + '@oxfmt/binding-linux-ppc64-gnu@0.42.0': + optional: true + + '@oxfmt/binding-linux-riscv64-gnu@0.42.0': + optional: true + + '@oxfmt/binding-linux-riscv64-musl@0.42.0': + optional: true + + '@oxfmt/binding-linux-s390x-gnu@0.42.0': + optional: true + + '@oxfmt/binding-linux-x64-gnu@0.42.0': + optional: true + + '@oxfmt/binding-linux-x64-musl@0.42.0': + optional: true + + '@oxfmt/binding-openharmony-arm64@0.42.0': + optional: true + + '@oxfmt/binding-win32-arm64-msvc@0.42.0': + optional: true + + '@oxfmt/binding-win32-ia32-msvc@0.42.0': + optional: true + + '@oxfmt/binding-win32-x64-msvc@0.42.0': + optional: true + + '@oxlint-tsgolint/darwin-arm64@0.17.3': + optional: true + + '@oxlint-tsgolint/darwin-x64@0.17.3': + optional: true + + '@oxlint-tsgolint/linux-arm64@0.17.3': + optional: true + + '@oxlint-tsgolint/linux-x64@0.17.3': + optional: true + + '@oxlint-tsgolint/win32-arm64@0.17.3': + optional: true + + '@oxlint-tsgolint/win32-x64@0.17.3': + optional: true + + '@oxlint/binding-android-arm-eabi@1.57.0': + optional: true + + '@oxlint/binding-android-arm64@1.57.0': + optional: true + + '@oxlint/binding-darwin-arm64@1.57.0': + optional: true + + '@oxlint/binding-darwin-x64@1.57.0': + optional: true + + '@oxlint/binding-freebsd-x64@1.57.0': + optional: true + + '@oxlint/binding-linux-arm-gnueabihf@1.57.0': + optional: true + + '@oxlint/binding-linux-arm-musleabihf@1.57.0': + optional: true + + 
'@oxlint/binding-linux-arm64-gnu@1.57.0': + optional: true + + '@oxlint/binding-linux-arm64-musl@1.57.0': + optional: true + + '@oxlint/binding-linux-ppc64-gnu@1.57.0': + optional: true + + '@oxlint/binding-linux-riscv64-gnu@1.57.0': + optional: true + + '@oxlint/binding-linux-riscv64-musl@1.57.0': + optional: true + + '@oxlint/binding-linux-s390x-gnu@1.57.0': + optional: true + + '@oxlint/binding-linux-x64-gnu@1.57.0': + optional: true + + '@oxlint/binding-linux-x64-musl@1.57.0': + optional: true + + '@oxlint/binding-openharmony-arm64@1.57.0': + optional: true + + '@oxlint/binding-win32-arm64-msvc@1.57.0': + optional: true + + '@oxlint/binding-win32-ia32-msvc@1.57.0': + optional: true + + '@oxlint/binding-win32-x64-msvc@1.57.0': + optional: true + + '@playwright/test@1.51.1': + dependencies: + playwright: 1.51.1 + + '@polka/url@1.0.0-next.29': {} + + '@rolldown/binding-android-arm64@1.0.0-rc.12': + optional: true + + '@rolldown/binding-darwin-arm64@1.0.0-rc.12': + optional: true + + '@rolldown/binding-darwin-x64@1.0.0-rc.12': + optional: true + + '@rolldown/binding-freebsd-x64@1.0.0-rc.12': + optional: true + + '@rolldown/binding-linux-arm-gnueabihf@1.0.0-rc.12': + optional: true + + '@rolldown/binding-linux-arm64-gnu@1.0.0-rc.12': + optional: true + + '@rolldown/binding-linux-arm64-musl@1.0.0-rc.12': + optional: true + + '@rolldown/binding-linux-ppc64-gnu@1.0.0-rc.12': + optional: true + + '@rolldown/binding-linux-s390x-gnu@1.0.0-rc.12': + optional: true + + '@rolldown/binding-linux-x64-gnu@1.0.0-rc.12': + optional: true + + '@rolldown/binding-linux-x64-musl@1.0.0-rc.12': + optional: true + + '@rolldown/binding-openharmony-arm64@1.0.0-rc.12': + optional: true + + '@rolldown/binding-wasm32-wasi@1.0.0-rc.12': + dependencies: + '@napi-rs/wasm-runtime': 1.1.1 + optional: true + + '@rolldown/binding-win32-arm64-msvc@1.0.0-rc.12': + optional: true + + '@rolldown/binding-win32-x64-msvc@1.0.0-rc.12': + optional: true + + '@rolldown/pluginutils@1.0.0-rc.12': {} + + 
'@standard-schema/spec@1.1.0': {} + + '@teppeis/multimaps@3.0.0': {} + + '@tybys/wasm-util@0.10.1': + dependencies: + tslib: 2.8.1 + optional: true + + '@types/chai@5.2.3': + dependencies: + '@types/deep-eql': 4.0.2 + assertion-error: 2.0.1 + + '@types/deep-eql@4.0.2': {} + + '@types/node@25.5.0': + dependencies: + undici-types: 7.18.2 + + '@types/normalize-package-data@2.4.4': {} + + '@voidzero-dev/vite-plus-core@0.1.14(@types/node@25.5.0)(esbuild@0.27.4)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3)': + dependencies: + '@oxc-project/runtime': 0.121.0 + '@oxc-project/types': 0.122.0 + lightningcss: 1.32.0 + postcss: 8.5.8 + optionalDependencies: + '@types/node': 25.5.0 + esbuild: 0.27.4 + fsevents: 2.3.3 + tsx: 4.21.0 + typescript: 5.9.3 + yaml: 2.8.3 + + '@voidzero-dev/vite-plus-darwin-arm64@0.1.14': + optional: true + + '@voidzero-dev/vite-plus-darwin-x64@0.1.14': + optional: true + + '@voidzero-dev/vite-plus-linux-arm64-gnu@0.1.14': + optional: true + + '@voidzero-dev/vite-plus-linux-arm64-musl@0.1.14': + optional: true + + '@voidzero-dev/vite-plus-linux-x64-gnu@0.1.14': + optional: true + + '@voidzero-dev/vite-plus-linux-x64-musl@0.1.14': + optional: true + + '@voidzero-dev/vite-plus-test@0.1.14(@types/node@25.5.0)(esbuild@0.27.4)(tsx@4.21.0)(typescript@5.9.3)(vite@8.0.3(@types/node@25.5.0)(esbuild@0.27.4)(tsx@4.21.0)(yaml@2.8.3))(yaml@2.8.3)': + dependencies: + '@standard-schema/spec': 1.1.0 + '@types/chai': 5.2.3 + '@voidzero-dev/vite-plus-core': 0.1.14(@types/node@25.5.0)(esbuild@0.27.4)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3) + es-module-lexer: 1.7.0 + obug: 2.1.1 + pixelmatch: 7.1.0 + pngjs: 7.0.0 + sirv: 3.0.2 + std-env: 4.0.0 + tinybench: 2.9.0 + tinyexec: 1.0.4 + tinyglobby: 0.2.15 + vite: 8.0.3(@types/node@25.5.0)(esbuild@0.27.4)(tsx@4.21.0)(yaml@2.8.3) + ws: 8.20.0 + optionalDependencies: + '@types/node': 25.5.0 + transitivePeerDependencies: + - '@arethetypeswrong/core' + - '@tsdown/css' + - '@tsdown/exe' + - '@vitejs/devtools' + - bufferutil + - 
esbuild + - jiti + - less + - publint + - sass + - sass-embedded + - stylus + - sugarss + - terser + - tsx + - typescript + - unplugin-unused + - utf-8-validate + - yaml + + '@voidzero-dev/vite-plus-win32-arm64-msvc@0.1.14': + optional: true + + '@voidzero-dev/vite-plus-win32-x64-msvc@0.1.14': + optional: true + + ansi-regex@4.1.1: {} + + ansi-regex@5.0.1: {} + + ansi-styles@4.3.0: + dependencies: + color-convert: 2.0.1 + + ansi-styles@5.2.0: {} + + any-promise@1.3.0: {} + + assertion-error-formatter@3.0.0: + dependencies: + diff: 4.0.4 + pad-right: 0.2.2 + repeat-string: 1.6.1 + + assertion-error@2.0.1: {} + + balanced-match@4.0.4: {} + + brace-expansion@5.0.5: + dependencies: + balanced-match: 4.0.4 + + buffer-from@1.1.2: {} + + cac@7.0.0: {} + + capital-case@1.0.4: + dependencies: + no-case: 3.0.4 + tslib: 2.8.1 + upper-case-first: 2.0.2 + + chalk@4.1.2: + dependencies: + ansi-styles: 4.3.0 + supports-color: 7.2.0 + + class-transformer@0.5.1: {} + + cli-table3@0.6.5: + dependencies: + string-width: 4.2.3 + optionalDependencies: + '@colors/colors': 1.5.0 + + color-convert@2.0.1: + dependencies: + color-name: 1.1.4 + + color-name@1.1.4: {} + + commander@14.0.0: {} + + commander@14.0.2: {} + + commander@14.0.3: {} + + cross-spawn@7.0.6: + dependencies: + path-key: 3.1.1 + shebang-command: 2.0.0 + which: 2.0.2 + + debug@4.4.3(supports-color@8.1.1): + dependencies: + ms: 2.1.3 + optionalDependencies: + supports-color: 8.1.1 + + detect-libc@2.1.2: {} + + diff@4.0.4: {} + + emoji-regex@8.0.0: {} + + error-stack-parser@2.1.4: + dependencies: + stackframe: 1.3.4 + + es-module-lexer@1.7.0: {} + + esbuild@0.27.4: + optionalDependencies: + '@esbuild/aix-ppc64': 0.27.4 + '@esbuild/android-arm': 0.27.4 + '@esbuild/android-arm64': 0.27.4 + '@esbuild/android-x64': 0.27.4 + '@esbuild/darwin-arm64': 0.27.4 + '@esbuild/darwin-x64': 0.27.4 + '@esbuild/freebsd-arm64': 0.27.4 + '@esbuild/freebsd-x64': 0.27.4 + '@esbuild/linux-arm': 0.27.4 + '@esbuild/linux-arm64': 0.27.4 + 
'@esbuild/linux-ia32': 0.27.4 + '@esbuild/linux-loong64': 0.27.4 + '@esbuild/linux-mips64el': 0.27.4 + '@esbuild/linux-ppc64': 0.27.4 + '@esbuild/linux-riscv64': 0.27.4 + '@esbuild/linux-s390x': 0.27.4 + '@esbuild/linux-x64': 0.27.4 + '@esbuild/netbsd-arm64': 0.27.4 + '@esbuild/netbsd-x64': 0.27.4 + '@esbuild/openbsd-arm64': 0.27.4 + '@esbuild/openbsd-x64': 0.27.4 + '@esbuild/openharmony-arm64': 0.27.4 + '@esbuild/sunos-x64': 0.27.4 + '@esbuild/win32-arm64': 0.27.4 + '@esbuild/win32-ia32': 0.27.4 + '@esbuild/win32-x64': 0.27.4 + + escape-string-regexp@1.0.5: {} + + fdir@6.5.0(picomatch@4.0.4): + optionalDependencies: + picomatch: 4.0.4 + + figures@3.2.0: + dependencies: + escape-string-regexp: 1.0.5 + + find-up-simple@1.0.1: {} + + fsevents@2.3.2: + optional: true + + fsevents@2.3.3: + optional: true + + get-tsconfig@4.13.7: + dependencies: + resolve-pkg-maps: 1.0.0 + + glob@13.0.6: + dependencies: + minimatch: 10.2.4 + minipass: 7.1.3 + path-scurry: 2.0.2 + + global-dirs@3.0.1: + dependencies: + ini: 2.0.0 + + has-ansi@4.0.1: + dependencies: + ansi-regex: 4.1.1 + + has-flag@4.0.0: {} + + hosted-git-info@9.0.2: + dependencies: + lru-cache: 11.2.7 + + indent-string@4.0.0: {} + + index-to-position@1.2.0: {} + + ini@2.0.0: {} + + is-fullwidth-code-point@3.0.0: {} + + is-installed-globally@0.4.0: + dependencies: + global-dirs: 3.0.1 + is-path-inside: 3.0.3 + + is-path-inside@3.0.3: {} + + is-stream@2.0.1: {} + + isexe@2.0.0: {} + + js-tokens@4.0.0: {} + + knuth-shuffle-seeded@1.0.6: + dependencies: + seed-random: 2.2.0 + + lightningcss-android-arm64@1.32.0: + optional: true + + lightningcss-darwin-arm64@1.32.0: + optional: true + + lightningcss-darwin-x64@1.32.0: + optional: true + + lightningcss-freebsd-x64@1.32.0: + optional: true + + lightningcss-linux-arm-gnueabihf@1.32.0: + optional: true + + lightningcss-linux-arm64-gnu@1.32.0: + optional: true + + lightningcss-linux-arm64-musl@1.32.0: + optional: true + + lightningcss-linux-x64-gnu@1.32.0: + optional: true + + 
lightningcss-linux-x64-musl@1.32.0: + optional: true + + lightningcss-win32-arm64-msvc@1.32.0: + optional: true + + lightningcss-win32-x64-msvc@1.32.0: + optional: true + + lightningcss@1.32.0: + dependencies: + detect-libc: 2.1.2 + optionalDependencies: + lightningcss-android-arm64: 1.32.0 + lightningcss-darwin-arm64: 1.32.0 + lightningcss-darwin-x64: 1.32.0 + lightningcss-freebsd-x64: 1.32.0 + lightningcss-linux-arm-gnueabihf: 1.32.0 + lightningcss-linux-arm64-gnu: 1.32.0 + lightningcss-linux-arm64-musl: 1.32.0 + lightningcss-linux-x64-gnu: 1.32.0 + lightningcss-linux-x64-musl: 1.32.0 + lightningcss-win32-arm64-msvc: 1.32.0 + lightningcss-win32-x64-msvc: 1.32.0 + + lodash.merge@4.6.2: {} + + lodash.mergewith@4.6.2: {} + + lodash.sortby@4.7.0: {} + + lower-case@2.0.2: + dependencies: + tslib: 2.8.1 + + lru-cache@11.2.7: {} + + luxon@3.7.2: {} + + mime@3.0.0: {} + + minimatch@10.2.4: + dependencies: + brace-expansion: 5.0.5 + + minipass@7.1.3: {} + + mkdirp@3.0.1: {} + + mrmime@2.0.1: {} + + ms@2.1.3: {} + + mz@2.7.0: + dependencies: + any-promise: 1.3.0 + object-assign: 4.1.1 + thenify-all: 1.6.0 + + nanoid@3.3.11: {} + + no-case@3.0.4: + dependencies: + lower-case: 2.0.2 + tslib: 2.8.1 + + normalize-package-data@8.0.0: + dependencies: + hosted-git-info: 9.0.2 + semver: 7.7.4 + validate-npm-package-license: 3.0.4 + + object-assign@4.1.1: {} + + obug@2.1.1: {} + + oxfmt@0.42.0: + dependencies: + tinypool: 2.1.0 + optionalDependencies: + '@oxfmt/binding-android-arm-eabi': 0.42.0 + '@oxfmt/binding-android-arm64': 0.42.0 + '@oxfmt/binding-darwin-arm64': 0.42.0 + '@oxfmt/binding-darwin-x64': 0.42.0 + '@oxfmt/binding-freebsd-x64': 0.42.0 + '@oxfmt/binding-linux-arm-gnueabihf': 0.42.0 + '@oxfmt/binding-linux-arm-musleabihf': 0.42.0 + '@oxfmt/binding-linux-arm64-gnu': 0.42.0 + '@oxfmt/binding-linux-arm64-musl': 0.42.0 + '@oxfmt/binding-linux-ppc64-gnu': 0.42.0 + '@oxfmt/binding-linux-riscv64-gnu': 0.42.0 + '@oxfmt/binding-linux-riscv64-musl': 0.42.0 + 
'@oxfmt/binding-linux-s390x-gnu': 0.42.0 + '@oxfmt/binding-linux-x64-gnu': 0.42.0 + '@oxfmt/binding-linux-x64-musl': 0.42.0 + '@oxfmt/binding-openharmony-arm64': 0.42.0 + '@oxfmt/binding-win32-arm64-msvc': 0.42.0 + '@oxfmt/binding-win32-ia32-msvc': 0.42.0 + '@oxfmt/binding-win32-x64-msvc': 0.42.0 + + oxlint-tsgolint@0.17.3: + optionalDependencies: + '@oxlint-tsgolint/darwin-arm64': 0.17.3 + '@oxlint-tsgolint/darwin-x64': 0.17.3 + '@oxlint-tsgolint/linux-arm64': 0.17.3 + '@oxlint-tsgolint/linux-x64': 0.17.3 + '@oxlint-tsgolint/win32-arm64': 0.17.3 + '@oxlint-tsgolint/win32-x64': 0.17.3 + + oxlint@1.57.0(oxlint-tsgolint@0.17.3): + optionalDependencies: + '@oxlint/binding-android-arm-eabi': 1.57.0 + '@oxlint/binding-android-arm64': 1.57.0 + '@oxlint/binding-darwin-arm64': 1.57.0 + '@oxlint/binding-darwin-x64': 1.57.0 + '@oxlint/binding-freebsd-x64': 1.57.0 + '@oxlint/binding-linux-arm-gnueabihf': 1.57.0 + '@oxlint/binding-linux-arm-musleabihf': 1.57.0 + '@oxlint/binding-linux-arm64-gnu': 1.57.0 + '@oxlint/binding-linux-arm64-musl': 1.57.0 + '@oxlint/binding-linux-ppc64-gnu': 1.57.0 + '@oxlint/binding-linux-riscv64-gnu': 1.57.0 + '@oxlint/binding-linux-riscv64-musl': 1.57.0 + '@oxlint/binding-linux-s390x-gnu': 1.57.0 + '@oxlint/binding-linux-x64-gnu': 1.57.0 + '@oxlint/binding-linux-x64-musl': 1.57.0 + '@oxlint/binding-openharmony-arm64': 1.57.0 + '@oxlint/binding-win32-arm64-msvc': 1.57.0 + '@oxlint/binding-win32-ia32-msvc': 1.57.0 + '@oxlint/binding-win32-x64-msvc': 1.57.0 + oxlint-tsgolint: 0.17.3 + + pad-right@0.2.2: + dependencies: + repeat-string: 1.6.1 + + parse-json@8.3.0: + dependencies: + '@babel/code-frame': 7.29.0 + index-to-position: 1.2.0 + type-fest: 4.41.0 + + path-key@3.1.1: {} + + path-scurry@2.0.2: + dependencies: + lru-cache: 11.2.7 + minipass: 7.1.3 + + picocolors@1.1.1: {} + + picomatch@4.0.4: {} + + pixelmatch@7.1.0: + dependencies: + pngjs: 7.0.0 + + playwright-core@1.51.1: {} + + playwright@1.51.1: + dependencies: + playwright-core: 1.51.1 + 
optionalDependencies: + fsevents: 2.3.2 + + pngjs@7.0.0: {} + + postcss@8.5.8: + dependencies: + nanoid: 3.3.11 + picocolors: 1.1.1 + source-map-js: 1.2.1 + + progress@2.0.3: {} + + property-expr@2.0.6: {} + + read-package-up@12.0.0: + dependencies: + find-up-simple: 1.0.1 + read-pkg: 10.1.0 + type-fest: 5.5.0 + + read-pkg@10.1.0: + dependencies: + '@types/normalize-package-data': 2.4.4 + normalize-package-data: 8.0.0 + parse-json: 8.3.0 + type-fest: 5.5.0 + unicorn-magic: 0.4.0 + + reflect-metadata@0.2.2: {} + + regexp-match-indices@1.0.2: + dependencies: + regexp-tree: 0.1.27 + + regexp-tree@0.1.27: {} + + repeat-string@1.6.1: {} + + resolve-pkg-maps@1.0.0: {} + + rolldown@1.0.0-rc.12: + dependencies: + '@oxc-project/types': 0.122.0 + '@rolldown/pluginutils': 1.0.0-rc.12 + optionalDependencies: + '@rolldown/binding-android-arm64': 1.0.0-rc.12 + '@rolldown/binding-darwin-arm64': 1.0.0-rc.12 + '@rolldown/binding-darwin-x64': 1.0.0-rc.12 + '@rolldown/binding-freebsd-x64': 1.0.0-rc.12 + '@rolldown/binding-linux-arm-gnueabihf': 1.0.0-rc.12 + '@rolldown/binding-linux-arm64-gnu': 1.0.0-rc.12 + '@rolldown/binding-linux-arm64-musl': 1.0.0-rc.12 + '@rolldown/binding-linux-ppc64-gnu': 1.0.0-rc.12 + '@rolldown/binding-linux-s390x-gnu': 1.0.0-rc.12 + '@rolldown/binding-linux-x64-gnu': 1.0.0-rc.12 + '@rolldown/binding-linux-x64-musl': 1.0.0-rc.12 + '@rolldown/binding-openharmony-arm64': 1.0.0-rc.12 + '@rolldown/binding-wasm32-wasi': 1.0.0-rc.12 + '@rolldown/binding-win32-arm64-msvc': 1.0.0-rc.12 + '@rolldown/binding-win32-x64-msvc': 1.0.0-rc.12 + + seed-random@2.2.0: {} + + semver@7.7.4: {} + + shebang-command@2.0.0: + dependencies: + shebang-regex: 3.0.0 + + shebang-regex@3.0.0: {} + + sirv@3.0.2: + dependencies: + '@polka/url': 1.0.0-next.29 + mrmime: 2.0.1 + totalist: 3.0.1 + + source-map-js@1.2.1: {} + + source-map-support@0.5.21: + dependencies: + buffer-from: 1.1.2 + source-map: 0.6.1 + + source-map@0.6.1: {} + + spdx-correct@3.2.0: + dependencies: + 
spdx-expression-parse: 3.0.1 + spdx-license-ids: 3.0.23 + + spdx-exceptions@2.5.0: {} + + spdx-expression-parse@3.0.1: + dependencies: + spdx-exceptions: 2.5.0 + spdx-license-ids: 3.0.23 + + spdx-license-ids@3.0.23: {} + + stackframe@1.3.4: {} + + std-env@4.0.0: {} + + string-argv@0.3.1: {} + + string-width@4.2.3: + dependencies: + emoji-regex: 8.0.0 + is-fullwidth-code-point: 3.0.0 + strip-ansi: 6.0.1 + + strip-ansi@6.0.1: + dependencies: + ansi-regex: 5.0.1 + + supports-color@7.2.0: + dependencies: + has-flag: 4.0.0 + + supports-color@8.1.1: + dependencies: + has-flag: 4.0.0 + + tagged-tag@1.0.0: {} + + thenify-all@1.6.0: + dependencies: + thenify: 3.3.1 + + thenify@3.3.1: + dependencies: + any-promise: 1.3.0 + + tiny-case@1.0.3: {} + + tinybench@2.9.0: {} + + tinyexec@1.0.4: {} + + tinyglobby@0.2.15: + dependencies: + fdir: 6.5.0(picomatch@4.0.4) + picomatch: 4.0.4 + + tinypool@2.1.0: {} + + toposort@2.0.2: {} + + totalist@3.0.1: {} + + ts-dedent@2.2.0: {} + + tslib@2.8.1: {} + + tsx@4.21.0: + dependencies: + esbuild: 0.27.4 + get-tsconfig: 4.13.7 + optionalDependencies: + fsevents: 2.3.3 + + type-fest@2.19.0: {} + + type-fest@4.41.0: {} + + type-fest@5.5.0: + dependencies: + tagged-tag: 1.0.0 + + typescript@5.9.3: {} + + undici-types@7.18.2: {} + + unicorn-magic@0.4.0: {} + + upper-case-first@2.0.2: + dependencies: + tslib: 2.8.1 + + util-arity@1.1.0: {} + + validate-npm-package-license@3.0.4: + dependencies: + spdx-correct: 3.2.0 + spdx-expression-parse: 3.0.1 + + vite-plus@0.1.14(@types/node@25.5.0)(esbuild@0.27.4)(tsx@4.21.0)(typescript@5.9.3)(vite@8.0.3(@types/node@25.5.0)(esbuild@0.27.4)(tsx@4.21.0)(yaml@2.8.3))(yaml@2.8.3): + dependencies: + '@oxc-project/types': 0.122.0 + '@voidzero-dev/vite-plus-core': 0.1.14(@types/node@25.5.0)(esbuild@0.27.4)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3) + '@voidzero-dev/vite-plus-test': 
0.1.14(@types/node@25.5.0)(esbuild@0.27.4)(tsx@4.21.0)(typescript@5.9.3)(vite@8.0.3(@types/node@25.5.0)(esbuild@0.27.4)(tsx@4.21.0)(yaml@2.8.3))(yaml@2.8.3) + cac: 7.0.0 + cross-spawn: 7.0.6 + oxfmt: 0.42.0 + oxlint: 1.57.0(oxlint-tsgolint@0.17.3) + oxlint-tsgolint: 0.17.3 + picocolors: 1.1.1 + optionalDependencies: + '@voidzero-dev/vite-plus-darwin-arm64': 0.1.14 + '@voidzero-dev/vite-plus-darwin-x64': 0.1.14 + '@voidzero-dev/vite-plus-linux-arm64-gnu': 0.1.14 + '@voidzero-dev/vite-plus-linux-arm64-musl': 0.1.14 + '@voidzero-dev/vite-plus-linux-x64-gnu': 0.1.14 + '@voidzero-dev/vite-plus-linux-x64-musl': 0.1.14 + '@voidzero-dev/vite-plus-win32-arm64-msvc': 0.1.14 + '@voidzero-dev/vite-plus-win32-x64-msvc': 0.1.14 + transitivePeerDependencies: + - '@arethetypeswrong/core' + - '@edge-runtime/vm' + - '@opentelemetry/api' + - '@tsdown/css' + - '@tsdown/exe' + - '@types/node' + - '@vitejs/devtools' + - '@vitest/ui' + - bufferutil + - esbuild + - happy-dom + - jiti + - jsdom + - less + - publint + - sass + - sass-embedded + - stylus + - sugarss + - terser + - tsx + - typescript + - unplugin-unused + - utf-8-validate + - vite + - yaml + + vite@8.0.3(@types/node@25.5.0)(esbuild@0.27.4)(tsx@4.21.0)(yaml@2.8.3): + dependencies: + lightningcss: 1.32.0 + picomatch: 4.0.4 + postcss: 8.5.8 + rolldown: 1.0.0-rc.12 + tinyglobby: 0.2.15 + optionalDependencies: + '@types/node': 25.5.0 + esbuild: 0.27.4 + fsevents: 2.3.3 + tsx: 4.21.0 + yaml: 2.8.3 + + which@2.0.2: + dependencies: + isexe: 2.0.0 + + ws@8.20.0: {} + + xmlbuilder@15.1.1: {} + + yaml@2.8.3: {} + + yup@1.7.1: + dependencies: + property-expr: 2.0.6 + tiny-case: 1.0.3 + toposort: 2.0.2 + type-fest: 2.19.0 diff --git a/e2e/scripts/common.ts b/e2e/scripts/common.ts new file mode 100644 index 0000000000..bb82121079 --- /dev/null +++ b/e2e/scripts/common.ts @@ -0,0 +1,242 @@ +import { spawn, type ChildProcess } from 'node:child_process' +import { access, copyFile, readFile, writeFile } from 'node:fs/promises' +import net from 
'node:net' +import path from 'node:path' +import { fileURLToPath, pathToFileURL } from 'node:url' +import { sleep } from '../support/process' + +type RunCommandOptions = { + command: string + args: string[] + cwd: string + env?: NodeJS.ProcessEnv + stdio?: 'inherit' | 'pipe' +} + +type RunCommandResult = { + exitCode: number + stdout: string + stderr: string +} + +type ForegroundProcessOptions = { + command: string + args: string[] + cwd: string + env?: NodeJS.ProcessEnv +} + +export const rootDir = fileURLToPath(new URL('../..', import.meta.url)) +export const e2eDir = path.join(rootDir, 'e2e') +export const apiDir = path.join(rootDir, 'api') +export const dockerDir = path.join(rootDir, 'docker') +export const webDir = path.join(rootDir, 'web') + +export const middlewareComposeFile = path.join(dockerDir, 'docker-compose.middleware.yaml') +export const middlewareEnvFile = path.join(dockerDir, 'middleware.env') +export const middlewareEnvExampleFile = path.join(dockerDir, 'middleware.env.example') +export const webEnvLocalFile = path.join(webDir, '.env.local') +export const webEnvExampleFile = path.join(webDir, '.env.example') +export const apiEnvExampleFile = path.join(apiDir, 'tests', 'integration_tests', '.env.example') + +const formatCommand = (command: string, args: string[]) => [command, ...args].join(' ') + +export const isMainModule = (metaUrl: string) => { + const entrypoint = process.argv[1] + if (!entrypoint) return false + + return pathToFileURL(entrypoint).href === metaUrl +} + +export const runCommand = async ({ + command, + args, + cwd, + env, + stdio = 'inherit', +}: RunCommandOptions): Promise => { + const childProcess = spawn(command, args, { + cwd, + env: { + ...process.env, + ...env, + }, + stdio: stdio === 'inherit' ? 
'inherit' : 'pipe', + }) + + let stdout = '' + let stderr = '' + + if (stdio === 'pipe') { + childProcess.stdout?.on('data', (chunk: Buffer | string) => { + stdout += chunk.toString() + }) + childProcess.stderr?.on('data', (chunk: Buffer | string) => { + stderr += chunk.toString() + }) + } + + return await new Promise((resolve, reject) => { + childProcess.once('error', reject) + childProcess.once('exit', (code) => { + resolve({ + exitCode: code ?? 1, + stdout, + stderr, + }) + }) + }) +} + +export const runCommandOrThrow = async (options: RunCommandOptions) => { + const result = await runCommand(options) + + if (result.exitCode !== 0) { + throw new Error( + `Command failed (${result.exitCode}): ${formatCommand(options.command, options.args)}`, + ) + } + + return result +} + +const forwardSignalsToChild = (childProcess: ChildProcess) => { + const handleSignal = (signal: NodeJS.Signals) => { + if (childProcess.exitCode === null) childProcess.kill(signal) + } + + const onSigint = () => handleSignal('SIGINT') + const onSigterm = () => handleSignal('SIGTERM') + + process.on('SIGINT', onSigint) + process.on('SIGTERM', onSigterm) + + return () => { + process.off('SIGINT', onSigint) + process.off('SIGTERM', onSigterm) + } +} + +export const runForegroundProcess = async ({ + command, + args, + cwd, + env, +}: ForegroundProcessOptions) => { + const childProcess = spawn(command, args, { + cwd, + env: { + ...process.env, + ...env, + }, + stdio: 'inherit', + }) + + const cleanupSignals = forwardSignalsToChild(childProcess) + const exitCode = await new Promise((resolve, reject) => { + childProcess.once('error', reject) + childProcess.once('exit', (code) => { + resolve(code ?? 
1) + }) + }) + + cleanupSignals() + process.exit(exitCode) +} + +export const ensureFileExists = async (filePath: string, exampleFilePath: string) => { + try { + await access(filePath) + } catch { + await copyFile(exampleFilePath, filePath) + } +} + +export const ensureLineInFile = async (filePath: string, line: string) => { + const fileContent = await readFile(filePath, 'utf8') + const lines = fileContent.split(/\r?\n/) + const assignmentPrefix = line.includes('=') ? `${line.slice(0, line.indexOf('='))}=` : null + + if (lines.includes(line)) return + + if (assignmentPrefix && lines.some((existingLine) => existingLine.startsWith(assignmentPrefix))) + return + + const normalizedContent = fileContent.endsWith('\n') ? fileContent : `${fileContent}\n` + await writeFile(filePath, `${normalizedContent}${line}\n`, 'utf8') +} + +export const ensureWebEnvLocal = async () => { + await ensureFileExists(webEnvLocalFile, webEnvExampleFile) + + const fileContent = await readFile(webEnvLocalFile, 'utf8') + const nextContent = fileContent.replaceAll('http://localhost:5001', 'http://127.0.0.1:5001') + + if (nextContent !== fileContent) await writeFile(webEnvLocalFile, nextContent, 'utf8') +} + +export const readSimpleDotenv = async (filePath: string) => { + const fileContent = await readFile(filePath, 'utf8') + const entries = fileContent + .split(/\r?\n/) + .map((line) => line.trim()) + .filter((line) => line && !line.startsWith('#')) + .map<[string, string]>((line) => { + const separatorIndex = line.indexOf('=') + const key = separatorIndex === -1 ? line : line.slice(0, separatorIndex).trim() + const rawValue = separatorIndex === -1 ? 
'' : line.slice(separatorIndex + 1).trim() + + if ( + (rawValue.startsWith('"') && rawValue.endsWith('"')) || + (rawValue.startsWith("'") && rawValue.endsWith("'")) + ) { + return [key, rawValue.slice(1, -1)] + } + + return [key, rawValue] + }) + + return Object.fromEntries(entries) +} + +export const waitForCondition = async ({ + check, + description, + intervalMs, + timeoutMs, +}: { + check: () => Promise | boolean + description: string + intervalMs: number + timeoutMs: number +}) => { + const deadline = Date.now() + timeoutMs + + while (Date.now() < deadline) { + if (await check()) return + + await sleep(intervalMs) + } + + throw new Error(`Timed out waiting for ${description} after ${timeoutMs}ms.`) +} + +export const isTcpPortReachable = async (host: string, port: number, timeoutMs = 1_000) => { + return await new Promise((resolve) => { + const socket = net.createConnection({ + host, + port, + }) + + const finish = (result: boolean) => { + socket.removeAllListeners() + socket.destroy() + resolve(result) + } + + socket.setTimeout(timeoutMs) + socket.once('connect', () => finish(true)) + socket.once('timeout', () => finish(false)) + socket.once('error', () => finish(false)) + }) +} diff --git a/e2e/scripts/run-cucumber.ts b/e2e/scripts/run-cucumber.ts new file mode 100644 index 0000000000..39e9157916 --- /dev/null +++ b/e2e/scripts/run-cucumber.ts @@ -0,0 +1,154 @@ +import { mkdir, rm } from 'node:fs/promises' +import path from 'node:path' +import { startWebServer, stopWebServer } from '../support/web-server' +import { waitForUrl, startLoggedProcess, stopManagedProcess } from '../support/process' +import { apiURL, baseURL, reuseExistingWebServer } from '../test-env' +import { e2eDir, isMainModule, runCommand } from './common' +import { resetState, startMiddleware, stopMiddleware } from './setup' + +type RunOptions = { + forwardArgs: string[] + full: boolean + headed: boolean +} + +const parseArgs = (argv: string[]): RunOptions => { + let full = false + let 
headed = false + const forwardArgs: string[] = [] + + for (let index = 0; index < argv.length; index += 1) { + const arg = argv[index] + + if (arg === '--') { + forwardArgs.push(...argv.slice(index + 1)) + break + } + + if (arg === '--full') { + full = true + continue + } + + if (arg === '--headed') { + headed = true + continue + } + + forwardArgs.push(arg) + } + + return { + forwardArgs, + full, + headed, + } +} + +const hasCustomTags = (forwardArgs: string[]) => + forwardArgs.some((arg) => arg === '--tags' || arg.startsWith('--tags=')) + +const main = async () => { + const { forwardArgs, full, headed } = parseArgs(process.argv.slice(2)) + const startMiddlewareForRun = full + const resetStateForRun = full + + if (resetStateForRun) await resetState() + + if (startMiddlewareForRun) await startMiddleware() + + const cucumberReportDir = path.join(e2eDir, 'cucumber-report') + const logDir = path.join(e2eDir, '.logs') + + await rm(cucumberReportDir, { force: true, recursive: true }) + await mkdir(logDir, { recursive: true }) + + const apiProcess = await startLoggedProcess({ + command: 'npx', + args: ['tsx', './scripts/setup.ts', 'api'], + cwd: e2eDir, + label: 'api server', + logFilePath: path.join(logDir, 'cucumber-api.log'), + }) + + let cleanupPromise: Promise | undefined + const cleanup = async () => { + if (!cleanupPromise) { + cleanupPromise = (async () => { + await stopWebServer() + await stopManagedProcess(apiProcess) + + if (startMiddlewareForRun) { + try { + await stopMiddleware() + } catch { + // Cleanup should continue even if middleware shutdown fails. 
+ } + } + })() + } + + await cleanupPromise + } + + const onTerminate = () => { + void cleanup().finally(() => { + process.exit(1) + }) + } + + process.once('SIGINT', onTerminate) + process.once('SIGTERM', onTerminate) + + try { + try { + await waitForUrl(`${apiURL}/health`, 180_000, 1_000) + } catch { + throw new Error(`API did not become ready at ${apiURL}/health.`) + } + + await startWebServer({ + baseURL, + command: 'npx', + args: ['tsx', './scripts/setup.ts', 'web'], + cwd: e2eDir, + logFilePath: path.join(logDir, 'cucumber-web.log'), + reuseExistingServer: reuseExistingWebServer, + timeoutMs: 300_000, + }) + + const cucumberEnv: NodeJS.ProcessEnv = { + ...process.env, + CUCUMBER_HEADLESS: headed ? '0' : '1', + } + + if (startMiddlewareForRun && !hasCustomTags(forwardArgs)) + cucumberEnv.E2E_CUCUMBER_TAGS = 'not @skip' + + const result = await runCommand({ + command: 'npx', + args: [ + 'tsx', + './node_modules/@cucumber/cucumber/bin/cucumber.js', + '--config', + './cucumber.config.ts', + ...forwardArgs, + ], + cwd: e2eDir, + env: cucumberEnv, + }) + + process.exitCode = result.exitCode + } finally { + process.off('SIGINT', onTerminate) + process.off('SIGTERM', onTerminate) + await cleanup() + } +} + +if (isMainModule(import.meta.url)) { + void main().catch((error) => { + console.error(error instanceof Error ? 
error.message : String(error)) + process.exit(1) + }) +} diff --git a/e2e/scripts/setup.ts b/e2e/scripts/setup.ts new file mode 100644 index 0000000000..6f38598df4 --- /dev/null +++ b/e2e/scripts/setup.ts @@ -0,0 +1,306 @@ +import { access, mkdir, rm } from 'node:fs/promises' +import path from 'node:path' +import { waitForUrl } from '../support/process' +import { + apiDir, + apiEnvExampleFile, + dockerDir, + e2eDir, + ensureFileExists, + ensureLineInFile, + ensureWebEnvLocal, + isMainModule, + isTcpPortReachable, + middlewareComposeFile, + middlewareEnvExampleFile, + middlewareEnvFile, + readSimpleDotenv, + runCommand, + runCommandOrThrow, + runForegroundProcess, + waitForCondition, + webDir, +} from './common' + +const buildIdPath = path.join(webDir, '.next', 'BUILD_ID') + +const middlewareDataPaths = [ + path.join(dockerDir, 'volumes', 'db', 'data'), + path.join(dockerDir, 'volumes', 'plugin_daemon'), + path.join(dockerDir, 'volumes', 'redis', 'data'), + path.join(dockerDir, 'volumes', 'weaviate'), +] + +const e2eStatePaths = [ + path.join(e2eDir, '.auth'), + path.join(e2eDir, 'cucumber-report'), + path.join(e2eDir, '.logs'), + path.join(e2eDir, 'playwright-report'), + path.join(e2eDir, 'test-results'), +] + +const composeArgs = [ + 'compose', + '-f', + middlewareComposeFile, + '--profile', + 'postgresql', + '--profile', + 'weaviate', +] + +const getApiEnvironment = async () => { + const envFromExample = await readSimpleDotenv(apiEnvExampleFile) + + return { + ...envFromExample, + FLASK_APP: 'app.py', + } +} + +const getServiceContainerId = async (service: string) => { + const result = await runCommandOrThrow({ + command: 'docker', + args: ['compose', '-f', middlewareComposeFile, 'ps', '-q', service], + cwd: dockerDir, + stdio: 'pipe', + }) + + return result.stdout.trim() +} + +const getContainerHealth = async (containerId: string) => { + const result = await runCommand({ + command: 'docker', + args: ['inspect', '-f', '{{.State.Health.Status}}', containerId], + 
cwd: dockerDir, + stdio: 'pipe', + }) + + if (result.exitCode !== 0) return '' + + return result.stdout.trim() +} + +const printComposeLogs = async (services: string[]) => { + await runCommand({ + command: 'docker', + args: ['compose', '-f', middlewareComposeFile, 'logs', ...services], + cwd: dockerDir, + }) +} + +const waitForDependency = async ({ + description, + services, + wait, +}: { + description: string + services: string[] + wait: () => Promise +}) => { + console.log(`Waiting for ${description}...`) + + try { + await wait() + } catch (error) { + await printComposeLogs(services) + throw error + } +} + +export const ensureWebBuild = async () => { + await ensureWebEnvLocal() + + if (process.env.E2E_FORCE_WEB_BUILD === '1') { + await runCommandOrThrow({ + command: 'pnpm', + args: ['run', 'build'], + cwd: webDir, + }) + return + } + + try { + await access(buildIdPath) + console.log('Reusing existing web build artifact.') + } catch { + await runCommandOrThrow({ + command: 'pnpm', + args: ['run', 'build'], + cwd: webDir, + }) + } +} + +export const startWeb = async () => { + await ensureWebBuild() + + await runForegroundProcess({ + command: 'pnpm', + args: ['run', 'start'], + cwd: webDir, + env: { + HOSTNAME: '127.0.0.1', + PORT: '3000', + }, + }) +} + +export const startApi = async () => { + const env = await getApiEnvironment() + + await runCommandOrThrow({ + command: 'uv', + args: ['run', '--project', '.', 'flask', 'upgrade-db'], + cwd: apiDir, + env, + }) + + await runForegroundProcess({ + command: 'uv', + args: ['run', '--project', '.', 'flask', 'run', '--host', '127.0.0.1', '--port', '5001'], + cwd: apiDir, + env, + }) +} + +export const stopMiddleware = async () => { + await runCommandOrThrow({ + command: 'docker', + args: [...composeArgs, 'down', '--remove-orphans'], + cwd: dockerDir, + }) +} + +export const resetState = async () => { + console.log('Stopping middleware services...') + try { + await stopMiddleware() + } catch { + // Reset should continue 
even if middleware is already stopped. + } + + console.log('Removing persisted middleware data...') + await Promise.all( + middlewareDataPaths.map(async (targetPath) => { + await rm(targetPath, { force: true, recursive: true }) + await mkdir(targetPath, { recursive: true }) + }), + ) + + console.log('Removing E2E local state...') + await Promise.all( + e2eStatePaths.map((targetPath) => rm(targetPath, { force: true, recursive: true })), + ) + + console.log('E2E state reset complete.') +} + +export const startMiddleware = async () => { + await ensureFileExists(middlewareEnvFile, middlewareEnvExampleFile) + await ensureLineInFile(middlewareEnvFile, 'COMPOSE_PROFILES=postgresql,weaviate') + + console.log('Starting middleware services...') + await runCommandOrThrow({ + command: 'docker', + args: [ + ...composeArgs, + 'up', + '-d', + 'db_postgres', + 'redis', + 'weaviate', + 'sandbox', + 'ssrf_proxy', + 'plugin_daemon', + ], + cwd: dockerDir, + }) + + const [postgresContainerId, redisContainerId] = await Promise.all([ + getServiceContainerId('db_postgres'), + getServiceContainerId('redis'), + ]) + + await waitForDependency({ + description: 'PostgreSQL and Redis health checks', + services: ['db_postgres', 'redis'], + wait: () => + waitForCondition({ + check: async () => { + const [postgresStatus, redisStatus] = await Promise.all([ + getContainerHealth(postgresContainerId), + getContainerHealth(redisContainerId), + ]) + + return postgresStatus === 'healthy' && redisStatus === 'healthy' + }, + description: 'PostgreSQL and Redis health checks', + intervalMs: 2_000, + timeoutMs: 240_000, + }), + }) + + await waitForDependency({ + description: 'Weaviate readiness', + services: ['weaviate'], + wait: () => waitForUrl('http://127.0.0.1:8080/v1/.well-known/ready', 120_000, 2_000), + }) + + await waitForDependency({ + description: 'sandbox health', + services: ['sandbox', 'ssrf_proxy'], + wait: () => waitForUrl('http://127.0.0.1:8194/health', 120_000, 2_000), + }) + + await 
waitForDependency({ + description: 'plugin daemon port', + services: ['plugin_daemon'], + wait: () => + waitForCondition({ + check: async () => isTcpPortReachable('127.0.0.1', 5002), + description: 'plugin daemon port', + intervalMs: 2_000, + timeoutMs: 120_000, + }), + }) + + console.log('Full middleware stack is ready.') +} + +const printUsage = () => { + console.log('Usage: tsx ./scripts/setup.ts ') +} + +const main = async () => { + const command = process.argv[2] + + switch (command) { + case 'api': + await startApi() + return + case 'middleware-down': + await stopMiddleware() + return + case 'middleware-up': + await startMiddleware() + return + case 'reset': + await resetState() + return + case 'web': + await startWeb() + return + default: + printUsage() + process.exitCode = 1 + } +} + +if (isMainModule(import.meta.url)) { + void main().catch((error) => { + console.error(error instanceof Error ? error.message : String(error)) + process.exit(1) + }) +} diff --git a/e2e/support/process.ts b/e2e/support/process.ts new file mode 100644 index 0000000000..96273ef931 --- /dev/null +++ b/e2e/support/process.ts @@ -0,0 +1,178 @@ +import type { ChildProcess } from 'node:child_process' +import { spawn } from 'node:child_process' +import { createWriteStream, type WriteStream } from 'node:fs' +import { mkdir } from 'node:fs/promises' +import net from 'node:net' +import { dirname } from 'node:path' + +type ManagedProcessOptions = { + command: string + args?: string[] + cwd: string + env?: NodeJS.ProcessEnv + label: string + logFilePath: string +} + +export type ManagedProcess = { + childProcess: ChildProcess + label: string + logFilePath: string + logStream: WriteStream +} + +export const sleep = (ms: number) => + new Promise((resolve) => { + setTimeout(resolve, ms) + }) + +export const isPortReachable = async (host: string, port: number, timeoutMs = 1_000) => { + return await new Promise((resolve) => { + const socket = net.createConnection({ + host, + port, + }) + + const 
finish = (result: boolean) => { + socket.removeAllListeners() + socket.destroy() + resolve(result) + } + + socket.setTimeout(timeoutMs) + socket.once('connect', () => finish(true)) + socket.once('timeout', () => finish(false)) + socket.once('error', () => finish(false)) + }) +} + +export const waitForUrl = async ( + url: string, + timeoutMs: number, + intervalMs = 1_000, + requestTimeoutMs = Math.max(intervalMs, 1_000), +) => { + const deadline = Date.now() + timeoutMs + + while (Date.now() < deadline) { + try { + const controller = new AbortController() + const timeout = setTimeout(() => controller.abort(), requestTimeoutMs) + + try { + const response = await fetch(url, { + signal: controller.signal, + }) + if (response.ok) return + } finally { + clearTimeout(timeout) + } + } catch { + // Keep polling until timeout. + } + + await sleep(intervalMs) + } + + throw new Error(`Timed out waiting for ${url} after ${timeoutMs}ms.`) +} + +export const startLoggedProcess = async ({ + command, + args = [], + cwd, + env, + label, + logFilePath, +}: ManagedProcessOptions): Promise => { + await mkdir(dirname(logFilePath), { recursive: true }) + + const logStream = createWriteStream(logFilePath, { flags: 'a' }) + const childProcess = spawn(command, args, { + cwd, + env: { + ...process.env, + ...env, + }, + detached: process.platform !== 'win32', + stdio: ['ignore', 'pipe', 'pipe'], + }) + + const formattedCommand = [command, ...args].join(' ') + logStream.write(`[${new Date().toISOString()}] Starting ${label}: ${formattedCommand}\n`) + childProcess.stdout?.pipe(logStream, { end: false }) + childProcess.stderr?.pipe(logStream, { end: false }) + + return { + childProcess, + label, + logFilePath, + logStream, + } +} + +const waitForProcessExit = (childProcess: ChildProcess, timeoutMs: number) => + new Promise((resolve) => { + if (childProcess.exitCode !== null) { + resolve() + return + } + + const timeout = setTimeout(() => { + cleanup() + resolve() + }, timeoutMs) + + const onExit 
= () => { + cleanup() + resolve() + } + + const cleanup = () => { + clearTimeout(timeout) + childProcess.off('exit', onExit) + } + + childProcess.once('exit', onExit) + }) + +const signalManagedProcess = (childProcess: ChildProcess, signal: NodeJS.Signals) => { + const { pid } = childProcess + if (!pid) return + + try { + if (process.platform !== 'win32') { + process.kill(-pid, signal) + return + } + + childProcess.kill(signal) + } catch { + // Best-effort shutdown. Cleanup continues even when the process is already gone. + } +} + +export const stopManagedProcess = async (managedProcess?: ManagedProcess) => { + if (!managedProcess) return + + const { childProcess, logStream } = managedProcess + + if (childProcess.exitCode === null) { + signalManagedProcess(childProcess, 'SIGTERM') + await waitForProcessExit(childProcess, 5_000) + } + + if (childProcess.exitCode === null) { + signalManagedProcess(childProcess, 'SIGKILL') + await waitForProcessExit(childProcess, 5_000) + } + + childProcess.stdout?.unpipe(logStream) + childProcess.stderr?.unpipe(logStream) + childProcess.stdout?.destroy() + childProcess.stderr?.destroy() + + await new Promise((resolve) => { + logStream.end(() => resolve()) + }) +} diff --git a/e2e/support/web-server.ts b/e2e/support/web-server.ts new file mode 100644 index 0000000000..ad5d5d916a --- /dev/null +++ b/e2e/support/web-server.ts @@ -0,0 +1,83 @@ +import type { ManagedProcess } from './process' +import { isPortReachable, startLoggedProcess, stopManagedProcess, waitForUrl } from './process' + +type WebServerStartOptions = { + baseURL: string + command: string + args?: string[] + cwd: string + logFilePath: string + reuseExistingServer: boolean + timeoutMs: number +} + +let activeProcess: ManagedProcess | undefined + +const getUrlHostAndPort = (url: string) => { + const parsedUrl = new URL(url) + const isHttps = parsedUrl.protocol === 'https:' + + return { + host: parsedUrl.hostname, + port: parsedUrl.port ? Number(parsedUrl.port) : isHttps ? 
443 : 80, + } +} + +export const startWebServer = async ({ + baseURL, + command, + args = [], + cwd, + logFilePath, + reuseExistingServer, + timeoutMs, +}: WebServerStartOptions) => { + const { host, port } = getUrlHostAndPort(baseURL) + + if (reuseExistingServer && (await isPortReachable(host, port))) return + + activeProcess = await startLoggedProcess({ + command, + args, + cwd, + label: 'web server', + logFilePath, + }) + + let startupError: Error | undefined + activeProcess.childProcess.once('error', (error) => { + startupError = error + }) + activeProcess.childProcess.once('exit', (code, signal) => { + if (startupError) return + + startupError = new Error( + `Web server exited before readiness (code: ${code ?? 'unknown'}, signal: ${signal ?? 'none'}).`, + ) + }) + + const deadline = Date.now() + timeoutMs + while (Date.now() < deadline) { + if (startupError) { + await stopManagedProcess(activeProcess) + activeProcess = undefined + throw startupError + } + + try { + await waitForUrl(baseURL, 1_000, 250, 1_000) + return + } catch { + // Continue polling until timeout or child exit. 
+ } + } + + await stopManagedProcess(activeProcess) + activeProcess = undefined + throw new Error(`Timed out waiting for web server readiness at ${baseURL} after ${timeoutMs}ms.`) +} + +export const stopWebServer = async () => { + await stopManagedProcess(activeProcess) + activeProcess = undefined +} diff --git a/e2e/test-env.ts b/e2e/test-env.ts new file mode 100644 index 0000000000..c0afc2a8c1 --- /dev/null +++ b/e2e/test-env.ts @@ -0,0 +1,12 @@ +export const defaultBaseURL = 'http://127.0.0.1:3000' +export const defaultApiURL = 'http://127.0.0.1:5001' +export const defaultLocale = 'en-US' + +export const baseURL = process.env.E2E_BASE_URL || defaultBaseURL +export const apiURL = process.env.E2E_API_URL || defaultApiURL + +export const cucumberHeadless = process.env.CUCUMBER_HEADLESS !== '0' +export const cucumberSlowMo = Number(process.env.E2E_SLOW_MO || 0) +export const reuseExistingWebServer = process.env.E2E_REUSE_WEB_SERVER + ? process.env.E2E_REUSE_WEB_SERVER !== '0' + : !process.env.CI diff --git a/e2e/tsconfig.json b/e2e/tsconfig.json new file mode 100644 index 0000000000..3976c12b66 --- /dev/null +++ b/e2e/tsconfig.json @@ -0,0 +1,25 @@ +{ + "compilerOptions": { + "target": "ES2023", + "lib": ["ES2023", "DOM"], + "module": "ESNext", + "moduleResolution": "Bundler", + "allowJs": false, + "resolveJsonModule": true, + "noEmit": true, + "strict": true, + "skipLibCheck": true, + "types": ["node", "@playwright/test", "@cucumber/cucumber"], + "isolatedModules": true, + "verbatimModuleSyntax": true + }, + "include": ["./**/*.ts"], + "exclude": [ + "./node_modules", + "./playwright-report", + "./test-results", + "./.auth", + "./cucumber-report", + "./.logs" + ] +} diff --git a/e2e/vite.config.ts b/e2e/vite.config.ts new file mode 100644 index 0000000000..98400d5b9b --- /dev/null +++ b/e2e/vite.config.ts @@ -0,0 +1,15 @@ +import { defineConfig } from 'vite-plus' + +export default defineConfig({ + lint: { + options: { + typeAware: true, + typeCheck: true, + 
denyWarnings: true, + }, + }, + fmt: { + singleQuote: true, + semi: false, + }, +}) diff --git a/sdks/nodejs-client/package.json b/sdks/nodejs-client/package.json index 728aa0d054..7168d33c24 100644 --- a/sdks/nodejs-client/package.json +++ b/sdks/nodejs-client/package.json @@ -70,7 +70,8 @@ "pnpm": { "overrides": { "flatted@<=3.4.1": "3.4.2", - "rollup@>=4.0.0,<4.59.0": "4.59.0" + "picomatch@>=4.0.0 <4.0.4": "4.0.4", + "rollup@>=4.0.0 <4.59.0": "4.59.0" } } } diff --git a/sdks/nodejs-client/pnpm-lock.yaml b/sdks/nodejs-client/pnpm-lock.yaml index c9081420f5..30d3cf61ee 100644 --- a/sdks/nodejs-client/pnpm-lock.yaml +++ b/sdks/nodejs-client/pnpm-lock.yaml @@ -6,7 +6,8 @@ settings: overrides: flatted@<=3.4.1: 3.4.2 - rollup@>=4.0.0,<4.59.0: 4.59.0 + picomatch@>=4.0.0 <4.0.4: 4.0.4 + rollup@>=4.0.0 <4.59.0: 4.59.0 importers: @@ -325,79 +326,66 @@ packages: resolution: {integrity: sha512-t4ONHboXi/3E0rT6OZl1pKbl2Vgxf9vJfWgmUoCEVQVxhW6Cw/c8I6hbbu7DAvgp82RKiH7TpLwxnJeKv2pbsw==} cpu: [arm] os: [linux] - libc: [glibc] '@rollup/rollup-linux-arm-musleabihf@4.59.0': resolution: {integrity: sha512-CikFT7aYPA2ufMD086cVORBYGHffBo4K8MQ4uPS/ZnY54GKj36i196u8U+aDVT2LX4eSMbyHtyOh7D7Zvk2VvA==} cpu: [arm] os: [linux] - libc: [musl] '@rollup/rollup-linux-arm64-gnu@4.59.0': resolution: {integrity: sha512-jYgUGk5aLd1nUb1CtQ8E+t5JhLc9x5WdBKew9ZgAXg7DBk0ZHErLHdXM24rfX+bKrFe+Xp5YuJo54I5HFjGDAA==} cpu: [arm64] os: [linux] - libc: [glibc] '@rollup/rollup-linux-arm64-musl@4.59.0': resolution: {integrity: sha512-peZRVEdnFWZ5Bh2KeumKG9ty7aCXzzEsHShOZEFiCQlDEepP1dpUl/SrUNXNg13UmZl+gzVDPsiCwnV1uI0RUA==} cpu: [arm64] os: [linux] - libc: [musl] '@rollup/rollup-linux-loong64-gnu@4.59.0': resolution: {integrity: sha512-gbUSW/97f7+r4gHy3Jlup8zDG190AuodsWnNiXErp9mT90iCy9NKKU0Xwx5k8VlRAIV2uU9CsMnEFg/xXaOfXg==} cpu: [loong64] os: [linux] - libc: [glibc] '@rollup/rollup-linux-loong64-musl@4.59.0': resolution: {integrity: 
sha512-yTRONe79E+o0FWFijasoTjtzG9EBedFXJMl888NBEDCDV9I2wGbFFfJQQe63OijbFCUZqxpHz1GzpbtSFikJ4Q==} cpu: [loong64] os: [linux] - libc: [musl] '@rollup/rollup-linux-ppc64-gnu@4.59.0': resolution: {integrity: sha512-sw1o3tfyk12k3OEpRddF68a1unZ5VCN7zoTNtSn2KndUE+ea3m3ROOKRCZxEpmT9nsGnogpFP9x6mnLTCaoLkA==} cpu: [ppc64] os: [linux] - libc: [glibc] '@rollup/rollup-linux-ppc64-musl@4.59.0': resolution: {integrity: sha512-+2kLtQ4xT3AiIxkzFVFXfsmlZiG5FXYW7ZyIIvGA7Bdeuh9Z0aN4hVyXS/G1E9bTP/vqszNIN/pUKCk/BTHsKA==} cpu: [ppc64] os: [linux] - libc: [musl] '@rollup/rollup-linux-riscv64-gnu@4.59.0': resolution: {integrity: sha512-NDYMpsXYJJaj+I7UdwIuHHNxXZ/b/N2hR15NyH3m2qAtb/hHPA4g4SuuvrdxetTdndfj9b1WOmy73kcPRoERUg==} cpu: [riscv64] os: [linux] - libc: [glibc] '@rollup/rollup-linux-riscv64-musl@4.59.0': resolution: {integrity: sha512-nLckB8WOqHIf1bhymk+oHxvM9D3tyPndZH8i8+35p/1YiVoVswPid2yLzgX7ZJP0KQvnkhM4H6QZ5m0LzbyIAg==} cpu: [riscv64] os: [linux] - libc: [musl] '@rollup/rollup-linux-s390x-gnu@4.59.0': resolution: {integrity: sha512-oF87Ie3uAIvORFBpwnCvUzdeYUqi2wY6jRFWJAy1qus/udHFYIkplYRW+wo+GRUP4sKzYdmE1Y3+rY5Gc4ZO+w==} cpu: [s390x] os: [linux] - libc: [glibc] '@rollup/rollup-linux-x64-gnu@4.59.0': resolution: {integrity: sha512-3AHmtQq/ppNuUspKAlvA8HtLybkDflkMuLK4DPo77DfthRb71V84/c4MlWJXixZz4uruIH4uaa07IqoAkG64fg==} cpu: [x64] os: [linux] - libc: [glibc] '@rollup/rollup-linux-x64-musl@4.59.0': resolution: {integrity: sha512-2UdiwS/9cTAx7qIUZB/fWtToJwvt0Vbo0zmnYt7ED35KPg13Q0ym1g442THLC7VyI6JfYTP4PiSOWyoMdV2/xg==} cpu: [x64] os: [linux] - libc: [musl] '@rollup/rollup-openbsd-x64@4.59.0': resolution: {integrity: sha512-M3bLRAVk6GOwFlPTIxVBSYKUaqfLrn8l0psKinkCFxl4lQvOSz8ZrKDz2gxcBwHFpci0B6rttydI4IpS4IS/jQ==} @@ -735,7 +723,7 @@ packages: resolution: {integrity: sha512-tIbYtZbucOs0BRGqPJkshJUYdL+SDH7dVM8gjy+ERp3WAUjLEFJE+02kanyHtwjWOnwrKYBiwAmM0p4kLJAnXg==} engines: {node: '>=12.0.0'} peerDependencies: - picomatch: ^3 || ^4 + picomatch: 4.0.4 peerDependenciesMeta: picomatch: optional: 
true @@ -963,8 +951,8 @@ packages: picocolors@1.1.1: resolution: {integrity: sha512-xceH2snhtb5M9liqDsmEw56le376mTZkEX/jEb/RxNFyegNul7eNslCXP9FDj/Lcu0X8KEyMceP2ntpaHrDEVA==} - picomatch@4.0.3: - resolution: {integrity: sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==} + picomatch@4.0.4: + resolution: {integrity: sha512-QP88BAKvMam/3NxH6vj2o21R6MjxZUAd6nlwAS/pnGvN9IVLocLHxGYIzFhg6fUQ+5th6P4dv4eW9jX3DSIj7A==} engines: {node: '>=12'} pirates@4.0.7: @@ -1829,9 +1817,9 @@ snapshots: fast-levenshtein@2.0.6: {} - fdir@6.5.0(picomatch@4.0.3): + fdir@6.5.0(picomatch@4.0.4): optionalDependencies: - picomatch: 4.0.3 + picomatch: 4.0.4 file-entry-cache@8.0.0: dependencies: @@ -2038,7 +2026,7 @@ snapshots: picocolors@1.1.1: {} - picomatch@4.0.3: {} + picomatch@4.0.4: {} pirates@4.0.7: {} @@ -2149,8 +2137,8 @@ snapshots: tinyglobby@0.2.15: dependencies: - fdir: 6.5.0(picomatch@4.0.3) - picomatch: 4.0.3 + fdir: 6.5.0(picomatch@4.0.4) + picomatch: 4.0.4 tinyrainbow@3.0.3: {} @@ -2207,8 +2195,8 @@ snapshots: vite@7.3.1(@types/node@25.4.0): dependencies: esbuild: 0.27.3 - fdir: 6.5.0(picomatch@4.0.3) - picomatch: 4.0.3 + fdir: 6.5.0(picomatch@4.0.4) + picomatch: 4.0.4 postcss: 8.5.8 rollup: 4.59.0 tinyglobby: 0.2.15 @@ -2230,7 +2218,7 @@ snapshots: magic-string: 0.30.21 obug: 2.1.1 pathe: 2.0.3 - picomatch: 4.0.3 + picomatch: 4.0.4 std-env: 3.10.0 tinybench: 2.9.0 tinyexec: 1.0.2 diff --git a/web/.env.example b/web/.env.example index 079c3bdeef..62d4fa6c56 100644 --- a/web/.env.example +++ b/web/.env.example @@ -51,8 +51,6 @@ NEXT_PUBLIC_ALLOW_EMBED= # Allow rendering unsafe URLs which have "data:" scheme. NEXT_PUBLIC_ALLOW_UNSAFE_DATA_SCHEME=false -# Github Access Token, used for invoking Github API -NEXT_PUBLIC_GITHUB_ACCESS_TOKEN= # The maximum number of top-k value for RAG. 
NEXT_PUBLIC_TOP_K_MAX_VALUE=10 diff --git a/web/__tests__/check-i18n.test.ts b/web/__tests__/check-i18n.test.ts deleted file mode 100644 index 9e9b3d7168..0000000000 --- a/web/__tests__/check-i18n.test.ts +++ /dev/null @@ -1,860 +0,0 @@ -import fs from 'node:fs' -import path from 'node:path' -import vm from 'node:vm' -import { transpile } from 'typescript' - -describe('i18n:check script functionality', () => { - const testDir = path.join(__dirname, '../i18n-test') - const testEnDir = path.join(testDir, 'en-US') - const testZhDir = path.join(testDir, 'zh-Hans') - - // Helper function that replicates the getKeysFromLanguage logic - async function getKeysFromLanguage(language: string, testPath = testDir): Promise { - return new Promise((resolve, reject) => { - const folderPath = path.resolve(testPath, language) - const allKeys: string[] = [] - - if (!fs.existsSync(folderPath)) { - resolve([]) - return - } - - fs.readdir(folderPath, (err, files) => { - if (err) { - reject(err) - return - } - - const translationFiles = files.filter(file => /\.(ts|js)$/.test(file)) - - translationFiles.forEach((file) => { - const filePath = path.join(folderPath, file) - const fileName = file.replace(/\.[^/.]+$/, '') - const camelCaseFileName = fileName.replace(/[-_](.)/g, (_, c) => - c.toUpperCase()) - - try { - const content = fs.readFileSync(filePath, 'utf8') - const moduleExports = {} - const context = { - exports: moduleExports, - module: { exports: moduleExports }, - require, - console, - __filename: filePath, - __dirname: folderPath, - } - - vm.runInNewContext(transpile(content), context) - const translationObj = (context.module.exports as any).default || context.module.exports - - if (!translationObj || typeof translationObj !== 'object') - throw new Error(`Error parsing file: ${filePath}`) - - const nestedKeys: string[] = [] - const iterateKeys = (obj: any, prefix = '') => { - for (const key in obj) { - const nestedKey = prefix ? 
`${prefix}.${key}` : key - if (typeof obj[key] === 'object' && obj[key] !== null && !Array.isArray(obj[key])) { - // This is an object (but not array), recurse into it but don't add it as a key - iterateKeys(obj[key], nestedKey) - } - else { - // This is a leaf node (string, number, boolean, array, etc.), add it as a key - nestedKeys.push(nestedKey) - } - } - } - iterateKeys(translationObj) - - const fileKeys = nestedKeys.map(key => `${camelCaseFileName}.${key}`) - allKeys.push(...fileKeys) - } - catch (error) { - reject(error) - } - }) - resolve(allKeys) - }) - }) - } - - beforeEach(() => { - // Clean up and create test directories - if (fs.existsSync(testDir)) - fs.rmSync(testDir, { recursive: true }) - - fs.mkdirSync(testDir, { recursive: true }) - fs.mkdirSync(testEnDir, { recursive: true }) - fs.mkdirSync(testZhDir, { recursive: true }) - }) - - afterEach(() => { - // Clean up test files - if (fs.existsSync(testDir)) - fs.rmSync(testDir, { recursive: true }) - }) - - describe('Key extraction logic', () => { - it('should extract only leaf node keys, not intermediate objects', async () => { - const testContent = `const translation = { - simple: 'Simple Value', - nested: { - level1: 'Level 1 Value', - deep: { - level2: 'Level 2 Value' - } - }, - array: ['not extracted'], - number: 42, - boolean: true -} - -export default translation -` - - fs.writeFileSync(path.join(testEnDir, 'test.ts'), testContent) - - const keys = await getKeysFromLanguage('en-US') - - expect(keys).toEqual([ - 'test.simple', - 'test.nested.level1', - 'test.nested.deep.level2', - 'test.array', - 'test.number', - 'test.boolean', - ]) - - // Should not include intermediate object keys - expect(keys).not.toContain('test.nested') - expect(keys).not.toContain('test.nested.deep') - }) - - it('should handle camelCase file name conversion correctly', async () => { - const testContent = `const translation = { - key: 'value' -} - -export default translation -` - - fs.writeFileSync(path.join(testEnDir, 
'app-debug.ts'), testContent) - fs.writeFileSync(path.join(testEnDir, 'user_profile.ts'), testContent) - - const keys = await getKeysFromLanguage('en-US') - - expect(keys).toContain('appDebug.key') - expect(keys).toContain('userProfile.key') - }) - }) - - describe('Missing keys detection', () => { - it('should detect missing keys in target language', async () => { - const enContent = `const translation = { - common: { - save: 'Save', - cancel: 'Cancel', - delete: 'Delete' - }, - app: { - title: 'My App', - version: '1.0' - } -} - -export default translation -` - - const zhContent = `const translation = { - common: { - save: '保存', - cancel: '取消' - // missing 'delete' - }, - app: { - title: '我的应用' - // missing 'version' - } -} - -export default translation -` - - fs.writeFileSync(path.join(testEnDir, 'test.ts'), enContent) - fs.writeFileSync(path.join(testZhDir, 'test.ts'), zhContent) - - const enKeys = await getKeysFromLanguage('en-US') - const zhKeys = await getKeysFromLanguage('zh-Hans') - - const missingKeys = enKeys.filter(key => !zhKeys.includes(key)) - - expect(missingKeys).toContain('test.common.delete') - expect(missingKeys).toContain('test.app.version') - expect(missingKeys).toHaveLength(2) - }) - }) - - describe('Extra keys detection', () => { - it('should detect extra keys in target language', async () => { - const enContent = `const translation = { - common: { - save: 'Save', - cancel: 'Cancel' - } -} - -export default translation -` - - const zhContent = `const translation = { - common: { - save: '保存', - cancel: '取消', - delete: '删除', // extra key - extra: '额外的' // another extra key - }, - newSection: { - someKey: '某个值' // extra section - } -} - -export default translation -` - - fs.writeFileSync(path.join(testEnDir, 'test.ts'), enContent) - fs.writeFileSync(path.join(testZhDir, 'test.ts'), zhContent) - - const enKeys = await getKeysFromLanguage('en-US') - const zhKeys = await getKeysFromLanguage('zh-Hans') - - const extraKeys = zhKeys.filter(key => 
!enKeys.includes(key)) - - expect(extraKeys).toContain('test.common.delete') - expect(extraKeys).toContain('test.common.extra') - expect(extraKeys).toContain('test.newSection.someKey') - expect(extraKeys).toHaveLength(3) - }) - }) - - describe('File filtering logic', () => { - it('should filter keys by specific file correctly', async () => { - // Create multiple files - const file1Content = `const translation = { - button: 'Button', - text: 'Text' -} - -export default translation -` - - const file2Content = `const translation = { - title: 'Title', - description: 'Description' -} - -export default translation -` - - fs.writeFileSync(path.join(testEnDir, 'components.ts'), file1Content) - fs.writeFileSync(path.join(testEnDir, 'pages.ts'), file2Content) - fs.writeFileSync(path.join(testZhDir, 'components.ts'), file1Content) - fs.writeFileSync(path.join(testZhDir, 'pages.ts'), file2Content) - - const allEnKeys = await getKeysFromLanguage('en-US') - - // Test file filtering logic - const targetFile = 'components' - const filteredEnKeys = allEnKeys.filter(key => - key.startsWith(targetFile.replace(/[-_](.)/g, (_, c) => c.toUpperCase())), - ) - - expect(allEnKeys).toHaveLength(4) // 2 keys from each file - expect(filteredEnKeys).toHaveLength(2) // only components keys - expect(filteredEnKeys).toContain('components.button') - expect(filteredEnKeys).toContain('components.text') - expect(filteredEnKeys).not.toContain('pages.title') - expect(filteredEnKeys).not.toContain('pages.description') - }) - }) - - describe('Complex nested structure handling', () => { - it('should handle deeply nested objects correctly', async () => { - const complexContent = `const translation = { - level1: { - level2: { - level3: { - level4: { - deepValue: 'Deep Value' - }, - anotherValue: 'Another Value' - }, - simpleValue: 'Simple Value' - }, - directValue: 'Direct Value' - }, - rootValue: 'Root Value' -} - -export default translation -` - - fs.writeFileSync(path.join(testEnDir, 'complex.ts'), 
complexContent) - - const keys = await getKeysFromLanguage('en-US') - - expect(keys).toContain('complex.level1.level2.level3.level4.deepValue') - expect(keys).toContain('complex.level1.level2.level3.anotherValue') - expect(keys).toContain('complex.level1.level2.simpleValue') - expect(keys).toContain('complex.level1.directValue') - expect(keys).toContain('complex.rootValue') - - // Should not include intermediate objects - expect(keys).not.toContain('complex.level1') - expect(keys).not.toContain('complex.level1.level2') - expect(keys).not.toContain('complex.level1.level2.level3') - expect(keys).not.toContain('complex.level1.level2.level3.level4') - }) - }) - - describe('Edge cases', () => { - it('should handle empty objects', async () => { - const emptyContent = `const translation = { - empty: {}, - withValue: 'value' -} - -export default translation -` - - fs.writeFileSync(path.join(testEnDir, 'empty.ts'), emptyContent) - - const keys = await getKeysFromLanguage('en-US') - - expect(keys).toContain('empty.withValue') - expect(keys).not.toContain('empty.empty') - }) - - it('should handle special characters in keys', async () => { - const specialContent = `const translation = { - 'key-with-dash': 'value1', - 'key_with_underscore': 'value2', - 'key.with.dots': 'value3', - normalKey: 'value4' -} - -export default translation -` - - fs.writeFileSync(path.join(testEnDir, 'special.ts'), specialContent) - - const keys = await getKeysFromLanguage('en-US') - - expect(keys).toContain('special.key-with-dash') - expect(keys).toContain('special.key_with_underscore') - expect(keys).toContain('special.key.with.dots') - expect(keys).toContain('special.normalKey') - }) - - it('should handle different value types', async () => { - const typesContent = `const translation = { - stringValue: 'string', - numberValue: 42, - booleanValue: true, - nullValue: null, - undefinedValue: undefined, - arrayValue: ['array', 'values'], - objectValue: { - nested: 'nested value' - } -} - -export 
default translation -` - - fs.writeFileSync(path.join(testEnDir, 'types.ts'), typesContent) - - const keys = await getKeysFromLanguage('en-US') - - expect(keys).toContain('types.stringValue') - expect(keys).toContain('types.numberValue') - expect(keys).toContain('types.booleanValue') - expect(keys).toContain('types.nullValue') - expect(keys).toContain('types.undefinedValue') - expect(keys).toContain('types.arrayValue') - expect(keys).toContain('types.objectValue.nested') - expect(keys).not.toContain('types.objectValue') - }) - }) - - describe('Real-world scenario tests', () => { - it('should handle app-debug structure like real files', async () => { - const appDebugEn = `const translation = { - pageTitle: { - line1: 'Prompt', - line2: 'Engineering' - }, - operation: { - applyConfig: 'Publish', - resetConfig: 'Reset', - debugConfig: 'Debug' - }, - generate: { - instruction: 'Instructions', - generate: 'Generate', - resTitle: 'Generated Prompt', - noDataLine1: 'Describe your use case on the left,', - noDataLine2: 'the orchestration preview will show here.' 
- } -} - -export default translation -` - - const appDebugZh = `const translation = { - pageTitle: { - line1: '提示词', - line2: '编排' - }, - operation: { - applyConfig: '发布', - resetConfig: '重置', - debugConfig: '调试' - }, - generate: { - instruction: '指令', - generate: '生成', - resTitle: '生成的提示词', - noData: '在左侧描述您的用例,编排预览将在此处显示。' // This is extra - } -} - -export default translation -` - - fs.writeFileSync(path.join(testEnDir, 'app-debug.ts'), appDebugEn) - fs.writeFileSync(path.join(testZhDir, 'app-debug.ts'), appDebugZh) - - const enKeys = await getKeysFromLanguage('en-US') - const zhKeys = await getKeysFromLanguage('zh-Hans') - - const missingKeys = enKeys.filter(key => !zhKeys.includes(key)) - const extraKeys = zhKeys.filter(key => !enKeys.includes(key)) - - expect(missingKeys).toContain('appDebug.generate.noDataLine1') - expect(missingKeys).toContain('appDebug.generate.noDataLine2') - expect(extraKeys).toContain('appDebug.generate.noData') - - expect(missingKeys).toHaveLength(2) - expect(extraKeys).toHaveLength(1) - }) - - it('should handle time structure with operation nested keys', async () => { - const timeEn = `const translation = { - months: { - January: 'January', - February: 'February' - }, - operation: { - now: 'Now', - ok: 'OK', - cancel: 'Cancel', - pickDate: 'Pick Date' - }, - title: { - pickTime: 'Pick Time' - }, - defaultPlaceholder: 'Pick a time...' -} - -export default translation -` - - const timeZh = `const translation = { - months: { - January: '一月', - February: '二月' - }, - operation: { - now: '此刻', - ok: '确定', - cancel: '取消', - pickDate: '选择日期' - }, - title: { - pickTime: '选择时间' - }, - pickDate: '选择日期', // This is extra - duplicates operation.pickDate - defaultPlaceholder: '请选择时间...' 
-} - -export default translation -` - - fs.writeFileSync(path.join(testEnDir, 'time.ts'), timeEn) - fs.writeFileSync(path.join(testZhDir, 'time.ts'), timeZh) - - const enKeys = await getKeysFromLanguage('en-US') - const zhKeys = await getKeysFromLanguage('zh-Hans') - - const missingKeys = enKeys.filter(key => !zhKeys.includes(key)) - const extraKeys = zhKeys.filter(key => !enKeys.includes(key)) - - expect(missingKeys).toHaveLength(0) // No missing keys - expect(extraKeys).toContain('time.pickDate') // Extra root-level pickDate - expect(extraKeys).toHaveLength(1) - - // Should have both keys available - expect(zhKeys).toContain('time.operation.pickDate') // Correct nested key - expect(zhKeys).toContain('time.pickDate') // Extra duplicate key - }) - }) - - describe('Statistics calculation', () => { - it('should calculate correct difference statistics', async () => { - const enContent = `const translation = { - key1: 'value1', - key2: 'value2', - key3: 'value3' -} - -export default translation -` - - const zhContentMissing = `const translation = { - key1: 'value1', - key2: 'value2' - // missing key3 -} - -export default translation -` - - const zhContentExtra = `const translation = { - key1: 'value1', - key2: 'value2', - key3: 'value3', - key4: 'extra', - key5: 'extra2' -} - -export default translation -` - - fs.writeFileSync(path.join(testEnDir, 'stats.ts'), enContent) - - // Test missing keys scenario - fs.writeFileSync(path.join(testZhDir, 'stats.ts'), zhContentMissing) - - const enKeys = await getKeysFromLanguage('en-US') - const zhKeysMissing = await getKeysFromLanguage('zh-Hans') - - expect(enKeys.length - zhKeysMissing.length).toBe(1) // +1 means 1 missing key - - // Test extra keys scenario - fs.writeFileSync(path.join(testZhDir, 'stats.ts'), zhContentExtra) - - const zhKeysExtra = await getKeysFromLanguage('zh-Hans') - - expect(enKeys.length - zhKeysExtra.length).toBe(-2) // -2 means 2 extra keys - }) - }) - - describe('Auto-remove multiline key-value pairs', 
() => { - // Helper function to simulate removeExtraKeysFromFile logic - function removeExtraKeysFromFile(content: string, keysToRemove: string[]): string { - const lines = content.split('\n') - const linesToRemove: number[] = [] - - for (const keyToRemove of keysToRemove) { - let targetLineIndex = -1 - const linesToRemoveForKey: number[] = [] - - // Find the key line (simplified for single-level keys in test) - for (let i = 0; i < lines.length; i++) { - const line = lines[i] - const keyPattern = new RegExp(`^\\s*${keyToRemove}\\s*:`) - if (keyPattern.test(line)) { - targetLineIndex = i - break - } - } - - if (targetLineIndex !== -1) { - linesToRemoveForKey.push(targetLineIndex) - - // Check if this is a multiline key-value pair - const keyLine = lines[targetLineIndex] - const trimmedKeyLine = keyLine.trim() - - // If key line ends with ":" (not complete value), it's likely multiline - if (trimmedKeyLine.endsWith(':') && !trimmedKeyLine.includes('{') && !/:\s*['"`]/.exec(trimmedKeyLine)) { - // Find the value lines that belong to this key - let currentLine = targetLineIndex + 1 - let foundValue = false - - while (currentLine < lines.length) { - const line = lines[currentLine] - const trimmed = line.trim() - - // Skip empty lines - if (trimmed === '') { - currentLine++ - continue - } - - // Check if this line starts a new key (indicates end of current value) - if (/^\w+\s*:/.exec(trimmed)) - break - - // Check if this line is part of the value - if (trimmed.startsWith('\'') || trimmed.startsWith('"') || trimmed.startsWith('`') || foundValue) { - linesToRemoveForKey.push(currentLine) - foundValue = true - - // Check if this line ends the value (ends with quote and comma/no comma) - if ((trimmed.endsWith('\',') || trimmed.endsWith('",') || trimmed.endsWith('`,') - || trimmed.endsWith('\'') || trimmed.endsWith('"') || trimmed.endsWith('`')) - && !trimmed.startsWith('//')) { - break - } - } - else { - break - } - - currentLine++ - } - } - - 
linesToRemove.push(...linesToRemoveForKey) - } - } - - // Remove duplicates and sort in reverse order - const uniqueLinesToRemove = [...new Set(linesToRemove)].sort((a, b) => b - a) - - for (const lineIndex of uniqueLinesToRemove) - lines.splice(lineIndex, 1) - - return lines.join('\n') - } - - it('should remove single-line key-value pairs correctly', () => { - const content = `const translation = { - keepThis: 'This should stay', - removeThis: 'This should be removed', - alsoKeep: 'This should also stay', -} - -export default translation` - - const result = removeExtraKeysFromFile(content, ['removeThis']) - - expect(result).toContain('keepThis: \'This should stay\'') - expect(result).toContain('alsoKeep: \'This should also stay\'') - expect(result).not.toContain('removeThis: \'This should be removed\'') - }) - - it('should remove multiline key-value pairs completely', () => { - const content = `const translation = { - keepThis: 'This should stay', - removeMultiline: - 'This is a multiline value that should be removed completely', - alsoKeep: 'This should also stay', -} - -export default translation` - - const result = removeExtraKeysFromFile(content, ['removeMultiline']) - - expect(result).toContain('keepThis: \'This should stay\'') - expect(result).toContain('alsoKeep: \'This should also stay\'') - expect(result).not.toContain('removeMultiline:') - expect(result).not.toContain('This is a multiline value that should be removed completely') - }) - - it('should handle mixed single-line and multiline removals', () => { - const content = `const translation = { - keepThis: 'Keep this', - removeSingle: 'Remove this single line', - removeMultiline: - 'Remove this multiline value', - anotherMultiline: - 'Another multiline that spans multiple lines', - keepAnother: 'Keep this too', -} - -export default translation` - - const result = removeExtraKeysFromFile(content, ['removeSingle', 'removeMultiline', 'anotherMultiline']) - - expect(result).toContain('keepThis: \'Keep 
this\'') - expect(result).toContain('keepAnother: \'Keep this too\'') - expect(result).not.toContain('removeSingle:') - expect(result).not.toContain('removeMultiline:') - expect(result).not.toContain('anotherMultiline:') - expect(result).not.toContain('Remove this single line') - expect(result).not.toContain('Remove this multiline value') - expect(result).not.toContain('Another multiline that spans multiple lines') - }) - - it('should properly detect multiline vs single-line patterns', () => { - const multilineContent = `const translation = { - singleLine: 'This is single line', - multilineKey: - 'This is multiline', - keyWithColon: 'Value with: colon inside', - objectKey: { - nested: 'value' - }, -} - -export default translation` - - // Test that single line with colon in value is not treated as multiline - const result1 = removeExtraKeysFromFile(multilineContent, ['keyWithColon']) - expect(result1).not.toContain('keyWithColon:') - expect(result1).not.toContain('Value with: colon inside') - - // Test that true multiline is handled correctly - const result2 = removeExtraKeysFromFile(multilineContent, ['multilineKey']) - expect(result2).not.toContain('multilineKey:') - expect(result2).not.toContain('This is multiline') - - // Test that object key removal works (note: this is a simplified test) - // In real scenario, object removal would be more complex - const result3 = removeExtraKeysFromFile(multilineContent, ['objectKey']) - expect(result3).not.toContain('objectKey: {') - // Note: Our simplified test function doesn't handle nested object removal perfectly - // This is acceptable as it's testing the main multiline string removal functionality - }) - - it('should handle real-world Polish translation structure', () => { - const polishContent = `const translation = { - createApp: 'UTWÓRZ APLIKACJĘ', - newApp: { - captionAppType: 'Jaki typ aplikacji chcesz stworzyć?', - chatbotDescription: - 'Zbuduj aplikację opartą na czacie. 
Ta aplikacja używa formatu pytań i odpowiedzi.', - agentDescription: - 'Zbuduj inteligentnego agenta, który może autonomicznie wybierać narzędzia.', - basic: 'Podstawowy', - }, -} - -export default translation` - - const result = removeExtraKeysFromFile(polishContent, ['captionAppType', 'chatbotDescription', 'agentDescription']) - - expect(result).toContain('createApp: \'UTWÓRZ APLIKACJĘ\'') - expect(result).toContain('basic: \'Podstawowy\'') - expect(result).not.toContain('captionAppType:') - expect(result).not.toContain('chatbotDescription:') - expect(result).not.toContain('agentDescription:') - expect(result).not.toContain('Jaki typ aplikacji') - expect(result).not.toContain('Zbuduj aplikację opartą na czacie') - expect(result).not.toContain('Zbuduj inteligentnego agenta') - }) - }) - - describe('Performance and Scalability', () => { - it('should handle large translation files efficiently', async () => { - // Create a large translation file with 1000 keys - const largeContent = `const translation = { -${Array.from({ length: 1000 }, (_, i) => ` key${i}: 'value${i}',`).join('\n')} -} - -export default translation` - - fs.writeFileSync(path.join(testEnDir, 'large.ts'), largeContent) - - const startTime = Date.now() - const keys = await getKeysFromLanguage('en-US') - const endTime = Date.now() - - expect(keys.length).toBe(1000) - expect(endTime - startTime).toBeLessThan(10000) - }) - - it('should handle multiple translation files concurrently', async () => { - // Create multiple files - for (let i = 0; i < 10; i++) { - const content = `const translation = { - key${i}: 'value${i}', - nested${i}: { - subkey: 'subvalue' - } -} - -export default translation` - fs.writeFileSync(path.join(testEnDir, `file${i}.ts`), content) - } - - const startTime = Date.now() - const keys = await getKeysFromLanguage('en-US') - const endTime = Date.now() - - expect(keys.length).toBe(20) // 10 files * 2 keys each - expect(endTime - startTime).toBeLessThan(10000) - }) - }) - - 
describe('Unicode and Internationalization', () => { - it('should handle Unicode characters in keys and values', async () => { - const unicodeContent = `const translation = { - '中文键': '中文值', - 'العربية': 'قيمة', - 'emoji_😀': 'value with emoji 🎉', - 'mixed_中文_English': 'mixed value' -} - -export default translation` - - fs.writeFileSync(path.join(testEnDir, 'unicode.ts'), unicodeContent) - - const keys = await getKeysFromLanguage('en-US') - - expect(keys).toContain('unicode.中文键') - expect(keys).toContain('unicode.العربية') - expect(keys).toContain('unicode.emoji_😀') - expect(keys).toContain('unicode.mixed_中文_English') - }) - - it('should handle RTL language files', async () => { - const rtlContent = `const translation = { - مرحبا: 'Hello', - العالم: 'World', - nested: { - مفتاح: 'key' - } -} - -export default translation` - - fs.writeFileSync(path.join(testEnDir, 'rtl.ts'), rtlContent) - - const keys = await getKeysFromLanguage('en-US') - - expect(keys).toContain('rtl.مرحبا') - expect(keys).toContain('rtl.العالم') - expect(keys).toContain('rtl.nested.مفتاح') - }) - }) - - describe('Error Recovery', () => { - it('should handle syntax errors in translation files gracefully', async () => { - const invalidContent = `const translation = { - validKey: 'valid value', - invalidKey: 'missing quote, - anotherKey: 'another value' -} - -export default translation` - - fs.writeFileSync(path.join(testEnDir, 'invalid.ts'), invalidContent) - - await expect(getKeysFromLanguage('en-US')).rejects.toThrow() - }) - }) -}) diff --git a/web/__tests__/plugins/plugin-install-flow.test.ts b/web/__tests__/plugins/plugin-install-flow.test.ts index 8fa2246198..dd5a18b724 100644 --- a/web/__tests__/plugins/plugin-install-flow.test.ts +++ b/web/__tests__/plugins/plugin-install-flow.test.ts @@ -5,11 +5,9 @@ * upload handling, and task status polling. Verifies the complete plugin * installation pipeline from source discovery to completion. 
*/ -import { beforeEach, describe, expect, it, vi } from 'vitest' -vi.mock('@/config', () => ({ - GITHUB_ACCESS_TOKEN: '', -})) +import { beforeEach, describe, expect, it, vi } from 'vitest' +import { checkForUpdates, fetchReleases, handleUpload } from '@/app/components/plugins/install-plugin/hooks' const mockToastNotify = vi.fn() vi.mock('@/app/components/base/ui/toast', () => ({ @@ -30,10 +28,6 @@ vi.mock('@/service/plugins', () => ({ checkTaskStatus: vi.fn(), })) -const { useGitHubReleases, useGitHubUpload } = await import( - '@/app/components/plugins/install-plugin/hooks', -) - describe('Plugin Installation Flow Integration', () => { beforeEach(() => { vi.clearAllMocks() @@ -44,22 +38,22 @@ describe('Plugin Installation Flow Integration', () => { it('fetches releases, checks for updates, and uploads the new version', async () => { const mockReleases = [ { - tag_name: 'v2.0.0', - assets: [{ browser_download_url: 'https://github.com/test/v2.difypkg', name: 'plugin-v2.difypkg' }], + tag: 'v2.0.0', + assets: [{ downloadUrl: 'https://github.com/test/v2.difypkg' }], }, { - tag_name: 'v1.5.0', - assets: [{ browser_download_url: 'https://github.com/test/v1.5.difypkg', name: 'plugin-v1.5.difypkg' }], + tag: 'v1.5.0', + assets: [{ downloadUrl: 'https://github.com/test/v1.5.difypkg' }], }, { - tag_name: 'v1.0.0', - assets: [{ browser_download_url: 'https://github.com/test/v1.difypkg', name: 'plugin-v1.difypkg' }], + tag: 'v1.0.0', + assets: [{ downloadUrl: 'https://github.com/test/v1.difypkg' }], }, ] ;(globalThis.fetch as ReturnType).mockResolvedValue({ ok: true, - json: () => Promise.resolve(mockReleases), + json: () => Promise.resolve({ releases: mockReleases }), }) mockUploadGitHub.mockResolvedValue({ @@ -67,8 +61,6 @@ describe('Plugin Installation Flow Integration', () => { unique_identifier: 'test-plugin:2.0.0', }) - const { fetchReleases, checkForUpdates } = useGitHubReleases() - const releases = await fetchReleases('test-org', 'test-repo') 
expect(releases).toHaveLength(3) expect(releases[0].tag_name).toBe('v2.0.0') @@ -77,7 +69,6 @@ describe('Plugin Installation Flow Integration', () => { expect(needUpdate).toBe(true) expect(toastProps.message).toContain('v2.0.0') - const { handleUpload } = useGitHubUpload() const onSuccess = vi.fn() const result = await handleUpload( 'https://github.com/test-org/test-repo', @@ -104,18 +95,16 @@ describe('Plugin Installation Flow Integration', () => { it('handles no new version available', async () => { const mockReleases = [ { - tag_name: 'v1.0.0', - assets: [{ browser_download_url: 'https://github.com/test/v1.difypkg', name: 'plugin-v1.difypkg' }], + tag: 'v1.0.0', + assets: [{ downloadUrl: 'https://github.com/test/v1.difypkg' }], }, ] ;(globalThis.fetch as ReturnType).mockResolvedValue({ ok: true, - json: () => Promise.resolve(mockReleases), + json: () => Promise.resolve({ releases: mockReleases }), }) - const { fetchReleases, checkForUpdates } = useGitHubReleases() - const releases = await fetchReleases('test-org', 'test-repo') const { needUpdate, toastProps } = checkForUpdates(releases, 'v1.0.0') @@ -127,11 +116,9 @@ describe('Plugin Installation Flow Integration', () => { it('handles empty releases', async () => { ;(globalThis.fetch as ReturnType).mockResolvedValue({ ok: true, - json: () => Promise.resolve([]), + json: () => Promise.resolve({ releases: [] }), }) - const { fetchReleases, checkForUpdates } = useGitHubReleases() - const releases = await fetchReleases('test-org', 'test-repo') expect(releases).toHaveLength(0) @@ -147,7 +134,6 @@ describe('Plugin Installation Flow Integration', () => { status: 404, }) - const { fetchReleases } = useGitHubReleases() const releases = await fetchReleases('nonexistent-org', 'nonexistent-repo') expect(releases).toEqual([]) @@ -159,7 +145,6 @@ describe('Plugin Installation Flow Integration', () => { it('handles upload failure gracefully', async () => { mockUploadGitHub.mockRejectedValue(new Error('Upload failed')) - const 
{ handleUpload } = useGitHubUpload() const onSuccess = vi.fn() await expect( diff --git a/web/app/components/app/app-access-control/access-control.spec.tsx b/web/app/components/app/app-access-control/access-control.spec.tsx index 3950bdf7ee..3a5f2272ed 100644 --- a/web/app/components/app/app-access-control/access-control.spec.tsx +++ b/web/app/components/app/app-access-control/access-control.spec.tsx @@ -109,7 +109,7 @@ beforeAll(() => { disconnect = vi.fn(() => undefined) unobserve = vi.fn(() => undefined) } - // @ts-expect-error jsdom does not implement IntersectionObserver + // @ts-expect-error test DOM typings do not guarantee IntersectionObserver here globalThis.IntersectionObserver = MockIntersectionObserver }) diff --git a/web/app/components/app/configuration/debug/debug-with-multiple-model/index.spec.tsx b/web/app/components/app/configuration/debug/debug-with-multiple-model/index.spec.tsx index 188086246a..389ab189e9 100644 --- a/web/app/components/app/configuration/debug/debug-with-multiple-model/index.spec.tsx +++ b/web/app/components/app/configuration/debug/debug-with-multiple-model/index.spec.tsx @@ -556,8 +556,8 @@ describe('DebugWithMultipleModel', () => { ) const twoItems = screen.getAllByTestId('debug-item') - expect(twoItems[0].style.width).toBe('calc(50% - 28px)') - expect(twoItems[1].style.width).toBe('calc(50% - 28px)') + expect(twoItems[0].style.width).toBe('calc(50% - 4px - 24px)') + expect(twoItems[1].style.width).toBe('calc(50% - 4px - 24px)') }) }) @@ -596,13 +596,13 @@ describe('DebugWithMultipleModel', () => { // Assert expect(items).toHaveLength(2) expectItemLayout(items[0], { - width: 'calc(50% - 28px)', + width: 'calc(50% - 4px - 24px)', height: '100%', transform: 'translateX(0) translateY(0)', classes: ['mr-2'], }) expectItemLayout(items[1], { - width: 'calc(50% - 28px)', + width: 'calc(50% - 4px - 24px)', height: '100%', transform: 'translateX(calc(100% + 8px)) translateY(0)', classes: [], @@ -620,19 +620,19 @@ 
describe('DebugWithMultipleModel', () => { // Assert expect(items).toHaveLength(3) expectItemLayout(items[0], { - width: 'calc(33.3% - 21.33px)', + width: 'calc(33.3% - 5.33px - 16px)', height: '100%', transform: 'translateX(0) translateY(0)', classes: ['mr-2'], }) expectItemLayout(items[1], { - width: 'calc(33.3% - 21.33px)', + width: 'calc(33.3% - 5.33px - 16px)', height: '100%', transform: 'translateX(calc(100% + 8px)) translateY(0)', classes: ['mr-2'], }) expectItemLayout(items[2], { - width: 'calc(33.3% - 21.33px)', + width: 'calc(33.3% - 5.33px - 16px)', height: '100%', transform: 'translateX(calc(200% + 16px)) translateY(0)', classes: [], @@ -655,25 +655,25 @@ describe('DebugWithMultipleModel', () => { // Assert expect(items).toHaveLength(4) expectItemLayout(items[0], { - width: 'calc(50% - 28px)', + width: 'calc(50% - 4px - 24px)', height: 'calc(50% - 4px)', transform: 'translateX(0) translateY(0)', classes: ['mr-2', 'mb-2'], }) expectItemLayout(items[1], { - width: 'calc(50% - 28px)', + width: 'calc(50% - 4px - 24px)', height: 'calc(50% - 4px)', transform: 'translateX(calc(100% + 8px)) translateY(0)', classes: ['mb-2'], }) expectItemLayout(items[2], { - width: 'calc(50% - 28px)', + width: 'calc(50% - 4px - 24px)', height: 'calc(50% - 4px)', transform: 'translateX(0) translateY(calc(100% + 8px))', classes: ['mr-2'], }) expectItemLayout(items[3], { - width: 'calc(50% - 28px)', + width: 'calc(50% - 4px - 24px)', height: 'calc(50% - 4px)', transform: 'translateX(calc(100% + 8px)) translateY(calc(100% + 8px))', classes: [], diff --git a/web/app/components/app/overview/settings/index.spec.tsx b/web/app/components/app/overview/settings/index.spec.tsx index b849b4f015..e933855ca8 100644 --- a/web/app/components/app/overview/settings/index.spec.tsx +++ b/web/app/components/app/overview/settings/index.spec.tsx @@ -1,6 +1,3 @@ -/** - * @vitest-environment jsdom - */ import type { ReactNode } from 'react' import type { ModalContextState } from 
'@/context/modal-context' import type { ProviderContextState } from '@/context/provider-context' diff --git a/web/app/components/base/action-button/__tests__/index.spec.tsx b/web/app/components/base/action-button/__tests__/index.spec.tsx index 949a980272..e9db157d0c 100644 --- a/web/app/components/base/action-button/__tests__/index.spec.tsx +++ b/web/app/components/base/action-button/__tests__/index.spec.tsx @@ -62,8 +62,8 @@ describe('ActionButton', () => { ) const button = screen.getByRole('button', { name: 'Custom Style' }) expect(button).toHaveStyle({ - color: 'rgb(255, 0, 0)', - backgroundColor: 'rgb(0, 0, 255)', + color: 'red', + backgroundColor: 'blue', }) }) diff --git a/web/app/components/base/chat/chat/__tests__/index.spec.tsx b/web/app/components/base/chat/chat/__tests__/index.spec.tsx index 781b5e86f3..0100b059f0 100644 --- a/web/app/components/base/chat/chat/__tests__/index.spec.tsx +++ b/web/app/components/base/chat/chat/__tests__/index.spec.tsx @@ -8,10 +8,10 @@ import Chat from '../index' // ─── Why each mock exists ───────────────────────────────────────────────────── // // Answer – transitively pulls Markdown (rehype/remark/katex), AgentContent, -// WorkflowProcessItem and Operation; none can resolve in jsdom. +// WorkflowProcessItem and Operation; none can resolve in the test DOM runtime. // Question – pulls Markdown, copy-to-clipboard, react-textarea-autosize. // ChatInputArea – pulls js-audio-recorder (requires Web Audio API unavailable in -// jsdom) and VoiceInput / FileContextProvider chains. +// the test DOM runtime) and VoiceInput / FileContextProvider chains. // PromptLogModal– pulls CopyFeedbackNew and deep modal dep chain. // AgentLogModal – pulls @remixicon/react (causes lint push error), useClickAway // from ahooks, and AgentLogDetail (workflow graph renderer). 
diff --git a/web/app/components/base/date-and-time-picker/calendar/__tests__/index.spec.tsx b/web/app/components/base/date-and-time-picker/calendar/__tests__/index.spec.tsx index d8e00780b1..8839798c15 100644 --- a/web/app/components/base/date-and-time-picker/calendar/__tests__/index.spec.tsx +++ b/web/app/components/base/date-and-time-picker/calendar/__tests__/index.spec.tsx @@ -3,7 +3,7 @@ import { fireEvent, render, screen } from '@testing-library/react' import dayjs from '../../utils/dayjs' import Calendar from '../index' -// Mock scrollIntoView since jsdom doesn't implement it +// Mock scrollIntoView since the test DOM runtime doesn't implement it beforeAll(() => { Element.prototype.scrollIntoView = vi.fn() }) diff --git a/web/app/components/base/date-and-time-picker/time-picker/__tests__/index.spec.tsx b/web/app/components/base/date-and-time-picker/time-picker/__tests__/index.spec.tsx index 910faf9cd4..199ed4ee41 100644 --- a/web/app/components/base/date-and-time-picker/time-picker/__tests__/index.spec.tsx +++ b/web/app/components/base/date-and-time-picker/time-picker/__tests__/index.spec.tsx @@ -3,7 +3,7 @@ import { fireEvent, render, screen, within } from '@testing-library/react' import dayjs, { isDayjsObject } from '../../utils/dayjs' import TimePicker from '../index' -// Mock scrollIntoView since jsdom doesn't implement it +// Mock scrollIntoView since the test DOM runtime doesn't implement it beforeAll(() => { Element.prototype.scrollIntoView = vi.fn() }) diff --git a/web/app/components/base/file-uploader/file-from-link-or-local/index.tsx b/web/app/components/base/file-uploader/file-from-link-or-local/index.tsx index 69496903a6..de9cc7ecd0 100644 --- a/web/app/components/base/file-uploader/file-from-link-or-local/index.tsx +++ b/web/app/components/base/file-uploader/file-from-link-or-local/index.tsx @@ -37,11 +37,11 @@ const FileFromLinkOrLocal = ({ const { handleLoadFileFromLink } = useFile(fileConfig) const disabled = !!fileConfig.number_limits && 
files.length >= fileConfig.number_limits const fileLinkPlaceholder = t('fileUploader.pasteFileLinkInputPlaceholder', { ns: 'common' }) - /* v8 ignore next -- fallback for missing i18n key is not reliably testable under current global translation mocks in jsdom @preserve */ + /* v8 ignore next -- fallback for a missing i18n key is not reliably testable under the current global translation mocks in the test DOM runtime. @preserve */ const fileLinkPlaceholderText = fileLinkPlaceholder || '' const handleSaveUrl = () => { - /* v8 ignore next -- guarded by UI-level disabled state (`disabled={!url || disabled}`), not reachable in jsdom click flow @preserve */ + /* v8 ignore next -- guarded by UI-level disabled state (`disabled={!url || disabled}`), not reachable in the current test click flow. @preserve */ if (!url) return diff --git a/web/app/components/base/icons/__tests__/utils.spec.ts b/web/app/components/base/icons/__tests__/utils.spec.ts index a25f39111d..f8534038bf 100644 --- a/web/app/components/base/icons/__tests__/utils.spec.ts +++ b/web/app/components/base/icons/__tests__/utils.spec.ts @@ -62,7 +62,7 @@ describe('generate icon base utils', () => { const { container } = render(generate(node, 'key')) // to svg element expect(container.firstChild).toHaveClass('container') - expect(container.querySelector('span')).toHaveStyle({ color: 'rgb(0, 0, 255)' }) + expect(container.querySelector('span')).toHaveStyle({ color: 'blue' }) }) // add not has children diff --git a/web/app/components/base/input/__tests__/index.spec.tsx b/web/app/components/base/input/__tests__/index.spec.tsx index 2c5b563a12..dfab8617c2 100644 --- a/web/app/components/base/input/__tests__/index.spec.tsx +++ b/web/app/components/base/input/__tests__/index.spec.tsx @@ -99,7 +99,7 @@ describe('Input component', () => { render() const input = screen.getByPlaceholderText(/input/i) expect(input).toHaveClass(customClass) - expect(input).toHaveStyle({ color: 'rgb(255, 0, 0)' }) + 
expect(input).toHaveStyle({ color: 'red' }) }) it('applies large size variant correctly', () => { diff --git a/web/app/components/base/markdown-blocks/__tests__/code-block.spec.tsx b/web/app/components/base/markdown-blocks/__tests__/code-block.spec.tsx index 745b7657d7..a16686801c 100644 --- a/web/app/components/base/markdown-blocks/__tests__/code-block.spec.tsx +++ b/web/app/components/base/markdown-blocks/__tests__/code-block.spec.tsx @@ -1,6 +1,6 @@ -import { createRequire } from 'node:module' import { act, render, screen, waitFor } from '@testing-library/react' import userEvent from '@testing-library/user-event' +import * as echarts from 'echarts' import { Theme } from '@/types/app' import CodeBlock from '../code-block' @@ -10,12 +10,21 @@ type UseThemeReturn = { } const mockUseTheme = vi.fn<() => UseThemeReturn>(() => ({ theme: Theme.light })) -const require = createRequire(import.meta.url) -const echartsCjs = require('echarts') as { - getInstanceByDom: (dom: HTMLDivElement | null) => { - resize: (opts?: { width?: string, height?: string }) => void - } | null -} +const mockEcharts = vi.hoisted(() => { + const state = { + finishedHandler: undefined as undefined | ((event?: unknown) => void), + echartsInstance: { + resize: vi.fn<(opts?: { width?: string, height?: string }) => void>(), + trigger: vi.fn((eventName: string, event?: unknown) => { + if (eventName === 'finished') + state.finishedHandler?.(event) + }), + }, + getInstanceByDom: vi.fn(() => state.echartsInstance), + } + + return state +}) let clientWidthSpy: { mockRestore: () => void } | null = null let clientHeightSpy: { mockRestore: () => void } | null = null @@ -61,6 +70,42 @@ vi.mock('@/hooks/use-theme', () => ({ default: () => mockUseTheme(), })) +vi.mock('echarts', () => ({ + getInstanceByDom: mockEcharts.getInstanceByDom, +})) + +vi.mock('echarts-for-react', async () => { + const React = await vi.importActual('react') + + const MockReactEcharts = React.forwardRef(({ + onChartReady, + onEvents, + 
}: { + onChartReady?: (instance: typeof mockEcharts.echartsInstance) => void + onEvents?: { finished?: (event?: unknown) => void } + }, ref: React.ForwardedRef<{ getEchartsInstance: () => typeof mockEcharts.echartsInstance }>) => { + React.useImperativeHandle(ref, () => ({ + getEchartsInstance: () => mockEcharts.echartsInstance, + })) + + React.useEffect(() => { + mockEcharts.finishedHandler = onEvents?.finished + onChartReady?.(mockEcharts.echartsInstance) + onEvents?.finished?.({}) + return () => { + mockEcharts.finishedHandler = undefined + } + }, [onChartReady, onEvents]) + + return
+ }) + + return { + __esModule: true, + default: MockReactEcharts, + } +}) + vi.mock('@/app/components/base/mermaid', () => ({ __esModule: true, default: ({ PrimitiveCode }: { PrimitiveCode: string }) =>
{PrimitiveCode}
, @@ -76,9 +121,9 @@ const findEchartsHost = async () => { const findEchartsInstance = async () => { const host = await findEchartsHost() await waitFor(() => { - expect(echartsCjs.getInstanceByDom(host)).toBeTruthy() + expect(echarts.getInstanceByDom(host)).toBeTruthy() }) - return echartsCjs.getInstanceByDom(host)! + return echarts.getInstanceByDom(host)! } describe('CodeBlock', () => { diff --git a/web/app/components/base/node-status/__tests__/index.spec.tsx b/web/app/components/base/node-status/__tests__/index.spec.tsx index f74af4965e..37b12946c8 100644 --- a/web/app/components/base/node-status/__tests__/index.spec.tsx +++ b/web/app/components/base/node-status/__tests__/index.spec.tsx @@ -41,7 +41,7 @@ describe('NodeStatus', () => { it('applies styleCss correctly', () => { const { container } = render() - expect(container.firstChild).toHaveStyle({ color: 'rgb(255, 0, 0)' }) + expect(container.firstChild).toHaveStyle({ color: 'red' }) }) it('applies iconClassName to the icon', () => { diff --git a/web/app/components/base/pagination/__tests__/pagination.spec.tsx b/web/app/components/base/pagination/__tests__/pagination.spec.tsx index 776802ff19..06eac9bfbd 100644 --- a/web/app/components/base/pagination/__tests__/pagination.spec.tsx +++ b/web/app/components/base/pagination/__tests__/pagination.spec.tsx @@ -131,7 +131,7 @@ describe('Pagination', () => { setCurrentPage, children: Prev, }) - fireEvent.keyPress(screen.getByText(/prev/i), { key: 'Enter', charCode: 13 }) + fireEvent.keyDown(screen.getByText(/prev/i).closest('button')!, { key: 'Enter', code: 'Enter', keyCode: 13, which: 13 }) expect(setCurrentPage).toHaveBeenCalledWith(2) }) @@ -142,7 +142,7 @@ describe('Pagination', () => { setCurrentPage, children: Prev, }) - fireEvent.keyPress(screen.getByText(/prev/i), { key: 'Enter', charCode: 13 }) + fireEvent.keyDown(screen.getByText(/prev/i).closest('button')!, { key: 'Enter', code: 'Enter', keyCode: 13, which: 13 }) expect(setCurrentPage).not.toHaveBeenCalled() 
}) @@ -213,7 +213,7 @@ describe('Pagination', () => { setCurrentPage, children: Next, }) - fireEvent.keyPress(screen.getByText(/next/i), { key: 'Enter', charCode: 13 }) + fireEvent.keyDown(screen.getByText(/next/i).closest('button')!, { key: 'Enter', code: 'Enter', keyCode: 13, which: 13 }) expect(setCurrentPage).toHaveBeenCalledWith(1) }) @@ -225,7 +225,7 @@ describe('Pagination', () => { setCurrentPage, children: Next, }) - fireEvent.keyPress(screen.getByText(/next/i), { key: 'Enter', charCode: 13 }) + fireEvent.keyDown(screen.getByText(/next/i).closest('button')!, { key: 'Enter', code: 'Enter', keyCode: 13, which: 13 }) expect(setCurrentPage).not.toHaveBeenCalled() }) @@ -318,7 +318,7 @@ describe('Pagination', () => { /> ), }) - fireEvent.keyPress(screen.getByText('4'), { key: 'Enter', charCode: 13 }) + fireEvent.keyDown(screen.getByText('4').closest('a')!, { key: 'Enter', code: 'Enter', keyCode: 13, which: 13 }) expect(setCurrentPage).toHaveBeenCalledWith(3) // 0-indexed }) diff --git a/web/app/components/base/pagination/pagination.tsx b/web/app/components/base/pagination/pagination.tsx index 0eb06b594c..b258090d80 100644 --- a/web/app/components/base/pagination/pagination.tsx +++ b/web/app/components/base/pagination/pagination.tsx @@ -50,7 +50,7 @@ export const PrevButton = ({ tabIndex={disabled ? '-1' : 0} disabled={disabled} data-testid={dataTestId} - onKeyPress={(event: React.KeyboardEvent) => { + onKeyDown={(event: React.KeyboardEvent) => { event.preventDefault() if (event.key === 'Enter' && !disabled) previous() @@ -85,7 +85,7 @@ export const NextButton = ({ tabIndex={disabled ? 
'-1' : 0} disabled={disabled} data-testid={dataTestId} - onKeyPress={(event: React.KeyboardEvent) => { + onKeyDown={(event: React.KeyboardEvent) => { event.preventDefault() if (event.key === 'Enter' && !disabled) next() @@ -140,7 +140,7 @@ export const PageButton = ({ }) || undefined } tabIndex={0} - onKeyPress={(event: React.KeyboardEvent) => { + onKeyDown={(event: React.KeyboardEvent) => { if (event.key === 'Enter') pagination.setCurrentPage(page - 1) }} diff --git a/web/app/components/base/premium-badge/__tests__/index.spec.tsx b/web/app/components/base/premium-badge/__tests__/index.spec.tsx index af8ace22f0..d107c07e52 100644 --- a/web/app/components/base/premium-badge/__tests__/index.spec.tsx +++ b/web/app/components/base/premium-badge/__tests__/index.spec.tsx @@ -41,6 +41,6 @@ describe('PremiumBadge', () => { ) const badge = screen.getByText('Premium') expect(badge).toBeInTheDocument() - expect(badge).toHaveStyle('background-color: rgb(255, 0, 0)') // Note: React converts 'red' to 'rgb(255, 0, 0)' + expect(badge).toHaveStyle('background-color: red') }) }) diff --git a/web/app/components/base/prompt-editor/plugins/__tests__/on-blur-or-focus-block.spec.tsx b/web/app/components/base/prompt-editor/plugins/__tests__/on-blur-or-focus-block.spec.tsx index dd2f74f7e5..a16ae9d823 100644 --- a/web/app/components/base/prompt-editor/plugins/__tests__/on-blur-or-focus-block.spec.tsx +++ b/web/app/components/base/prompt-editor/plugins/__tests__/on-blur-or-focus-block.spec.tsx @@ -3,13 +3,10 @@ import { LexicalComposer } from '@lexical/react/LexicalComposer' import { act, render, waitFor } from '@testing-library/react' import { BLUR_COMMAND, - COMMAND_PRIORITY_EDITOR, FOCUS_COMMAND, - KEY_ESCAPE_COMMAND, } from 'lexical' import OnBlurBlock from '../on-blur-or-focus-block' import { CaptureEditorPlugin } from '../test-utils' -import { CLEAR_HIDE_MENU_TIMEOUT } from '../workflow-variable-block' const renderOnBlurBlock = (props?: { onBlur?: () => void @@ -75,7 +72,7 @@ 
describe('OnBlurBlock', () => { expect(onFocus).toHaveBeenCalledTimes(1) }) - it('should call onBlur and dispatch escape after delay when blur target is not var-search-input', async () => { + it('should call onBlur when blur target is not var-search-input', async () => { const onBlur = vi.fn() const { getEditor } = renderOnBlurBlock({ onBlur }) @@ -85,14 +82,6 @@ describe('OnBlurBlock', () => { const editor = getEditor() expect(editor).not.toBeNull() - vi.useFakeTimers() - - const onEscape = vi.fn(() => true) - const unregister = editor!.registerCommand( - KEY_ESCAPE_COMMAND, - onEscape, - COMMAND_PRIORITY_EDITOR, - ) let handled = false act(() => { @@ -101,18 +90,9 @@ describe('OnBlurBlock', () => { expect(handled).toBe(true) expect(onBlur).toHaveBeenCalledTimes(1) - expect(onEscape).not.toHaveBeenCalled() - - act(() => { - vi.advanceTimersByTime(200) - }) - - expect(onEscape).toHaveBeenCalledTimes(1) - unregister() - vi.useRealTimers() }) - it('should dispatch delayed escape when onBlur callback is not provided', async () => { + it('should handle blur when onBlur callback is not provided', async () => { const { getEditor } = renderOnBlurBlock() await waitFor(() => { @@ -121,28 +101,16 @@ describe('OnBlurBlock', () => { const editor = getEditor() expect(editor).not.toBeNull() - vi.useFakeTimers() - - const onEscape = vi.fn(() => true) - const unregister = editor!.registerCommand( - KEY_ESCAPE_COMMAND, - onEscape, - COMMAND_PRIORITY_EDITOR, - ) + let handled = false act(() => { - editor!.dispatchCommand(BLUR_COMMAND, createBlurEvent(document.createElement('div'))) - }) - act(() => { - vi.advanceTimersByTime(200) + handled = editor!.dispatchCommand(BLUR_COMMAND, createBlurEvent(document.createElement('div'))) }) - expect(onEscape).toHaveBeenCalledTimes(1) - unregister() - vi.useRealTimers() + expect(handled).toBe(true) }) - it('should skip onBlur and delayed escape when blur target is var-search-input', async () => { + it('should skip onBlur when blur target is 
var-search-input', async () => { const onBlur = vi.fn() const { getEditor } = renderOnBlurBlock({ onBlur }) @@ -152,31 +120,17 @@ describe('OnBlurBlock', () => { const editor = getEditor() expect(editor).not.toBeNull() - vi.useFakeTimers() const target = document.createElement('input') target.classList.add('var-search-input') - const onEscape = vi.fn(() => true) - const unregister = editor!.registerCommand( - KEY_ESCAPE_COMMAND, - onEscape, - COMMAND_PRIORITY_EDITOR, - ) - let handled = false act(() => { handled = editor!.dispatchCommand(BLUR_COMMAND, createBlurEvent(target)) }) - act(() => { - vi.advanceTimersByTime(200) - }) expect(handled).toBe(true) expect(onBlur).not.toHaveBeenCalled() - expect(onEscape).not.toHaveBeenCalled() - unregister() - vi.useRealTimers() }) it('should handle focus command when onFocus callback is not provided', async () => { @@ -198,59 +152,6 @@ describe('OnBlurBlock', () => { }) }) - describe('Clear timeout command', () => { - it('should clear scheduled escape timeout when clear command is dispatched', async () => { - const { getEditor } = renderOnBlurBlock({ onBlur: vi.fn() }) - - await waitFor(() => { - expect(getEditor()).not.toBeNull() - }) - - const editor = getEditor() - expect(editor).not.toBeNull() - vi.useFakeTimers() - - const onEscape = vi.fn(() => true) - const unregister = editor!.registerCommand( - KEY_ESCAPE_COMMAND, - onEscape, - COMMAND_PRIORITY_EDITOR, - ) - - act(() => { - editor!.dispatchCommand(BLUR_COMMAND, createBlurEvent(document.createElement('div'))) - }) - act(() => { - editor!.dispatchCommand(CLEAR_HIDE_MENU_TIMEOUT, undefined) - }) - act(() => { - vi.advanceTimersByTime(200) - }) - - expect(onEscape).not.toHaveBeenCalled() - unregister() - vi.useRealTimers() - }) - - it('should handle clear command when no timeout is scheduled', async () => { - const { getEditor } = renderOnBlurBlock() - - await waitFor(() => { - expect(getEditor()).not.toBeNull() - }) - - const editor = getEditor() - 
expect(editor).not.toBeNull() - - let handled = false - act(() => { - handled = editor!.dispatchCommand(CLEAR_HIDE_MENU_TIMEOUT, undefined) - }) - - expect(handled).toBe(true) - }) - }) - describe('Lifecycle cleanup', () => { it('should unregister commands when component unmounts', async () => { const { getEditor, unmount } = renderOnBlurBlock() @@ -266,16 +167,13 @@ describe('OnBlurBlock', () => { let blurHandled = true let focusHandled = true - let clearHandled = true act(() => { blurHandled = editor!.dispatchCommand(BLUR_COMMAND, createBlurEvent(document.createElement('div'))) focusHandled = editor!.dispatchCommand(FOCUS_COMMAND, createFocusEvent()) - clearHandled = editor!.dispatchCommand(CLEAR_HIDE_MENU_TIMEOUT, undefined) }) expect(blurHandled).toBe(false) expect(focusHandled).toBe(false) - expect(clearHandled).toBe(false) }) }) }) diff --git a/web/app/components/base/prompt-editor/plugins/__tests__/update-block.spec.tsx b/web/app/components/base/prompt-editor/plugins/__tests__/update-block.spec.tsx index 8f6a72a7de..4283910c31 100644 --- a/web/app/components/base/prompt-editor/plugins/__tests__/update-block.spec.tsx +++ b/web/app/components/base/prompt-editor/plugins/__tests__/update-block.spec.tsx @@ -1,14 +1,13 @@ import type { LexicalEditor } from 'lexical' import { LexicalComposer } from '@lexical/react/LexicalComposer' import { act, render, waitFor } from '@testing-library/react' -import { $getRoot, COMMAND_PRIORITY_EDITOR } from 'lexical' +import { $getRoot } from 'lexical' import { CustomTextNode } from '../custom-text/node' import { CaptureEditorPlugin } from '../test-utils' import UpdateBlock, { PROMPT_EDITOR_INSERT_QUICKLY, PROMPT_EDITOR_UPDATE_VALUE_BY_EVENT_EMITTER, } from '../update-block' -import { CLEAR_HIDE_MENU_TIMEOUT } from '../workflow-variable-block' const { mockUseEventEmitterContextContext } = vi.hoisted(() => ({ mockUseEventEmitterContextContext: vi.fn(), @@ -157,7 +156,7 @@ describe('UpdateBlock', () => { }) describe('Quick insert 
event', () => { - it('should insert slash and dispatch clear command when quick insert event matches instance id', async () => { + it('should insert slash when quick insert event matches instance id', async () => { const { emit, getEditor } = setup({ instanceId: 'instance-1' }) await waitFor(() => { @@ -168,13 +167,6 @@ describe('UpdateBlock', () => { selectRootEnd(editor!) - const clearCommandHandler = vi.fn(() => true) - const unregister = editor!.registerCommand( - CLEAR_HIDE_MENU_TIMEOUT, - clearCommandHandler, - COMMAND_PRIORITY_EDITOR, - ) - emit({ type: PROMPT_EDITOR_INSERT_QUICKLY, instanceId: 'instance-1', @@ -183,9 +175,6 @@ describe('UpdateBlock', () => { await waitFor(() => { expect(readEditorText(editor!)).toBe('/') }) - expect(clearCommandHandler).toHaveBeenCalledTimes(1) - - unregister() }) it('should ignore quick insert event when instance id does not match', async () => { diff --git a/web/app/components/base/prompt-editor/plugins/component-picker-block/__tests__/index.spec.tsx b/web/app/components/base/prompt-editor/plugins/component-picker-block/__tests__/index.spec.tsx index 6cc6c3a67f..51b14b76c8 100644 --- a/web/app/components/base/prompt-editor/plugins/component-picker-block/__tests__/index.spec.tsx +++ b/web/app/components/base/prompt-editor/plugins/component-picker-block/__tests__/index.spec.tsx @@ -23,6 +23,8 @@ import { $createTextNode, $getRoot, $setSelection, + BLUR_COMMAND, + FOCUS_COMMAND, KEY_ESCAPE_COMMAND, } from 'lexical' import * as React from 'react' @@ -631,4 +633,180 @@ describe('ComponentPicker (component-picker-block/index.tsx)', () => { // With a single option group, the only divider should be the workflow-var/options separator. 
expect(document.querySelectorAll('.bg-divider-subtle')).toHaveLength(1) }) + + describe('blur/focus menu visibility', () => { + it('hides the menu after a 200ms delay when blur command is dispatched', async () => { + const captures: Captures = { editor: null, eventEmitter: null } + + render(( + + )) + + const editor = await waitForEditor(captures) + await setEditorText(editor, '{', true) + expect(await screen.findByText('common.promptEditor.context.item.title')).toBeInTheDocument() + + vi.useFakeTimers() + + act(() => { + editor.dispatchCommand(BLUR_COMMAND, new FocusEvent('blur', { relatedTarget: document.createElement('button') })) + }) + + expect(screen.queryByText('common.promptEditor.context.item.title')).toBeInTheDocument() + + act(() => { + vi.advanceTimersByTime(200) + }) + + expect(screen.queryByText('common.promptEditor.context.item.title')).not.toBeInTheDocument() + + vi.useRealTimers() + }) + + it('restores menu visibility when focus command is dispatched after blur hides it', async () => { + const captures: Captures = { editor: null, eventEmitter: null } + + render(( + + )) + + const editor = await waitForEditor(captures) + await setEditorText(editor, '{', true) + expect(await screen.findByText('common.promptEditor.context.item.title')).toBeInTheDocument() + + vi.useFakeTimers() + + act(() => { + editor.dispatchCommand(BLUR_COMMAND, new FocusEvent('blur', { relatedTarget: document.createElement('button') })) + }) + act(() => { + vi.advanceTimersByTime(200) + }) + + expect(screen.queryByText('common.promptEditor.context.item.title')).not.toBeInTheDocument() + + act(() => { + editor.dispatchCommand(FOCUS_COMMAND, new FocusEvent('focus')) + }) + + vi.useRealTimers() + + await setEditorText(editor, '{', true) + await waitFor(() => { + expect(screen.queryByText('common.promptEditor.context.item.title')).toBeInTheDocument() + }) + }) + + it('cancels the blur timer when focus arrives before the 200ms timeout', async () => { + const captures: Captures = { 
editor: null, eventEmitter: null } + + render(( + + )) + + const editor = await waitForEditor(captures) + await setEditorText(editor, '{', true) + expect(await screen.findByText('common.promptEditor.context.item.title')).toBeInTheDocument() + + vi.useFakeTimers() + + act(() => { + editor.dispatchCommand(BLUR_COMMAND, new FocusEvent('blur', { relatedTarget: document.createElement('button') })) + }) + + act(() => { + editor.dispatchCommand(FOCUS_COMMAND, new FocusEvent('focus')) + }) + + act(() => { + vi.advanceTimersByTime(200) + }) + + expect(screen.queryByText('common.promptEditor.context.item.title')).toBeInTheDocument() + + vi.useRealTimers() + }) + + it('cancels a pending blur timer when a subsequent blur targets var-search-input', async () => { + const captures: Captures = { editor: null, eventEmitter: null } + + render(( + + )) + + const editor = await waitForEditor(captures) + await setEditorText(editor, '{', true) + expect(await screen.findByText('common.promptEditor.context.item.title')).toBeInTheDocument() + + vi.useFakeTimers() + + act(() => { + editor.dispatchCommand(BLUR_COMMAND, new FocusEvent('blur', { relatedTarget: document.createElement('button') })) + }) + + const varInput = document.createElement('input') + varInput.classList.add('var-search-input') + + act(() => { + editor.dispatchCommand(BLUR_COMMAND, new FocusEvent('blur', { relatedTarget: varInput })) + }) + + act(() => { + vi.advanceTimersByTime(200) + }) + + expect(screen.queryByText('common.promptEditor.context.item.title')).toBeInTheDocument() + + vi.useRealTimers() + }) + + it('does not hide the menu when blur target is var-search-input', async () => { + const captures: Captures = { editor: null, eventEmitter: null } + + render(( + + )) + + const editor = await waitForEditor(captures) + await setEditorText(editor, '{', true) + expect(await screen.findByText('common.promptEditor.context.item.title')).toBeInTheDocument() + + vi.useFakeTimers() + + const target = 
document.createElement('input') + target.classList.add('var-search-input') + + act(() => { + editor.dispatchCommand(BLUR_COMMAND, new FocusEvent('blur', { relatedTarget: target })) + }) + + act(() => { + vi.advanceTimersByTime(200) + }) + + expect(screen.queryByText('common.promptEditor.context.item.title')).toBeInTheDocument() + + vi.useRealTimers() + }) + }) }) diff --git a/web/app/components/base/prompt-editor/plugins/component-picker-block/index.tsx b/web/app/components/base/prompt-editor/plugins/component-picker-block/index.tsx index 8001a2755b..bebc1b59af 100644 --- a/web/app/components/base/prompt-editor/plugins/component-picker-block/index.tsx +++ b/web/app/components/base/prompt-editor/plugins/component-picker-block/index.tsx @@ -21,11 +21,19 @@ import { } from '@floating-ui/react' import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext' import { LexicalTypeaheadMenuPlugin } from '@lexical/react/LexicalTypeaheadMenuPlugin' -import { KEY_ESCAPE_COMMAND } from 'lexical' +import { mergeRegister } from '@lexical/utils' +import { + BLUR_COMMAND, + COMMAND_PRIORITY_EDITOR, + FOCUS_COMMAND, + KEY_ESCAPE_COMMAND, +} from 'lexical' import { Fragment, memo, useCallback, + useEffect, + useRef, useState, } from 'react' import ReactDOM from 'react-dom' @@ -87,6 +95,46 @@ const ComponentPicker = ({ }) const [queryString, setQueryString] = useState(null) + const [blurHidden, setBlurHidden] = useState(false) + const blurTimerRef = useRef | null>(null) + + const clearBlurTimer = useCallback(() => { + if (blurTimerRef.current) { + clearTimeout(blurTimerRef.current) + blurTimerRef.current = null + } + }, []) + + useEffect(() => { + const unregister = mergeRegister( + editor.registerCommand( + BLUR_COMMAND, + (event) => { + clearBlurTimer() + const target = event?.relatedTarget as HTMLElement + if (!target?.classList?.contains('var-search-input')) + blurTimerRef.current = setTimeout(() => setBlurHidden(true), 200) + return false + }, + 
COMMAND_PRIORITY_EDITOR, + ), + editor.registerCommand( + FOCUS_COMMAND, + () => { + clearBlurTimer() + setBlurHidden(false) + return false + }, + COMMAND_PRIORITY_EDITOR, + ), + ) + + return () => { + if (blurTimerRef.current) + clearTimeout(blurTimerRef.current) + unregister() + } + }, [editor, clearBlurTimer]) eventEmitter?.useSubscription((v: any) => { if (v.type === INSERT_VARIABLE_VALUE_BLOCK_COMMAND) @@ -159,6 +207,8 @@ const ComponentPicker = ({ anchorElementRef, { options, selectedIndex, selectOptionAndCleanUp, setHighlightedIndex }, ) => { + if (blurHidden) + return null if (!(anchorElementRef.current && (allFlattenOptions.length || workflowVariableBlock?.show))) return null @@ -240,7 +290,7 @@ const ComponentPicker = ({ } ) - }, [allFlattenOptions.length, workflowVariableBlock?.show, floatingStyles, isPositioned, refs, workflowVariableOptions, isSupportFileVar, handleClose, currentBlock?.generatorType, handleSelectWorkflowVariable, queryString, workflowVariableBlock?.showManageInputField, workflowVariableBlock?.onManageInputField]) + }, [blurHidden, allFlattenOptions.length, workflowVariableBlock?.show, floatingStyles, isPositioned, refs, workflowVariableOptions, isSupportFileVar, handleClose, currentBlock?.generatorType, handleSelectWorkflowVariable, queryString, workflowVariableBlock?.showManageInputField, workflowVariableBlock?.onManageInputField]) return ( void @@ -20,35 +18,13 @@ const OnBlurBlock: FC = ({ }) => { const [editor] = useLexicalComposerContext() - const ref = useRef | null>(null) - useEffect(() => { - const clearHideMenuTimeout = () => { - if (ref.current) { - clearTimeout(ref.current) - ref.current = null - } - } - - const unregister = mergeRegister( - editor.registerCommand( - CLEAR_HIDE_MENU_TIMEOUT, - () => { - clearHideMenuTimeout() - return true - }, - COMMAND_PRIORITY_EDITOR, - ), + return mergeRegister( editor.registerCommand( BLUR_COMMAND, (event) => { - // Check if the clicked target element is var-search-input const target = 
event?.relatedTarget as HTMLElement if (!target?.classList?.contains('var-search-input')) { - clearHideMenuTimeout() - ref.current = setTimeout(() => { - editor.dispatchCommand(KEY_ESCAPE_COMMAND, new KeyboardEvent('keydown', { key: 'Escape' })) - }, 200) if (onBlur) onBlur() } @@ -66,11 +42,6 @@ const OnBlurBlock: FC = ({ COMMAND_PRIORITY_EDITOR, ), ) - - return () => { - clearHideMenuTimeout() - unregister() - } }, [editor, onBlur, onFocus]) return null diff --git a/web/app/components/base/prompt-editor/plugins/shortcuts-popup-plugin/index.tsx b/web/app/components/base/prompt-editor/plugins/shortcuts-popup-plugin/index.tsx index abe6ea9a45..7dcda803f2 100644 --- a/web/app/components/base/prompt-editor/plugins/shortcuts-popup-plugin/index.tsx +++ b/web/app/components/base/prompt-editor/plugins/shortcuts-popup-plugin/index.tsx @@ -141,7 +141,7 @@ export default function ShortcutsPopupPlugin({ const portalRef = useRef(null) const lastSelectionRef = useRef(null) - /* v8 ignore next -- defensive non-browser fallback; this client-only plugin runs where document exists (browser/jsdom). @preserve */ + /* v8 ignore next -- defensive non-browser fallback; this client-only plugin runs where document exists (browser/test DOM runtime). @preserve */ const containerEl = useMemo(() => container ?? (typeof document !== 'undefined' ? document.body : null), [container]) const useContainer = !!containerEl && containerEl !== document.body @@ -210,7 +210,7 @@ export default function ShortcutsPopupPlugin({ if (rect.width === 0 && rect.height === 0) { const root = editor.getRootElement() - /* v8 ignore next 10 -- zero-size rect recovery depends on browser layout/selection geometry; deterministic reproduction in jsdom is unreliable. @preserve */ + /* v8 ignore next 10 -- zero-size rect recovery depends on browser layout/selection geometry; deterministic reproduction in the test DOM runtime is unreliable. 
@preserve */ if (root) { const sc = range.startContainer const node = sc.nodeType === Node.ELEMENT_NODE diff --git a/web/app/components/base/prompt-editor/plugins/update-block.tsx b/web/app/components/base/prompt-editor/plugins/update-block.tsx index bf89a259af..2d83573b1f 100644 --- a/web/app/components/base/prompt-editor/plugins/update-block.tsx +++ b/web/app/components/base/prompt-editor/plugins/update-block.tsx @@ -3,7 +3,6 @@ import { $insertNodes } from 'lexical' import { useEventEmitterContextContext } from '@/context/event-emitter' import { textToEditorState } from '../utils' import { CustomTextNode } from './custom-text/node' -import { CLEAR_HIDE_MENU_TIMEOUT } from './workflow-variable-block' export const PROMPT_EDITOR_UPDATE_VALUE_BY_EVENT_EMITTER = 'PROMPT_EDITOR_UPDATE_VALUE_BY_EVENT_EMITTER' export const PROMPT_EDITOR_INSERT_QUICKLY = 'PROMPT_EDITOR_INSERT_QUICKLY' @@ -30,8 +29,6 @@ const UpdateBlock = ({ editor.update(() => { const textNode = new CustomTextNode('/') $insertNodes([textNode]) - - editor.dispatchCommand(CLEAR_HIDE_MENU_TIMEOUT, undefined) }) } }) diff --git a/web/app/components/base/prompt-editor/plugins/workflow-variable-block/__tests__/index.spec.tsx b/web/app/components/base/prompt-editor/plugins/workflow-variable-block/__tests__/index.spec.tsx index ca4973b830..1591dc44f9 100644 --- a/web/app/components/base/prompt-editor/plugins/workflow-variable-block/__tests__/index.spec.tsx +++ b/web/app/components/base/prompt-editor/plugins/workflow-variable-block/__tests__/index.spec.tsx @@ -9,7 +9,6 @@ import { $insertNodes, COMMAND_PRIORITY_EDITOR } from 'lexical' import { Type } from '@/app/components/workflow/nodes/llm/types' import { BlockEnum } from '@/app/components/workflow/types' import { - CLEAR_HIDE_MENU_TIMEOUT, DELETE_WORKFLOW_VARIABLE_BLOCK_COMMAND, INSERT_WORKFLOW_VARIABLE_BLOCK_COMMAND, UPDATE_WORKFLOW_NODES_MAP, @@ -134,7 +133,6 @@ describe('WorkflowVariableBlock', () => { const insertHandler = 
mockRegisterCommand.mock.calls[0][1] as (variables: string[]) => boolean const result = insertHandler(['node-1', 'answer']) - expect(mockDispatchCommand).toHaveBeenCalledWith(CLEAR_HIDE_MENU_TIMEOUT, undefined) expect($createWorkflowVariableBlockNode).toHaveBeenCalledWith( ['node-1', 'answer'], workflowNodesMap, diff --git a/web/app/components/base/prompt-editor/plugins/workflow-variable-block/index.tsx b/web/app/components/base/prompt-editor/plugins/workflow-variable-block/index.tsx index 76b2795803..c8cac64d19 100644 --- a/web/app/components/base/prompt-editor/plugins/workflow-variable-block/index.tsx +++ b/web/app/components/base/prompt-editor/plugins/workflow-variable-block/index.tsx @@ -18,7 +18,6 @@ import { export const INSERT_WORKFLOW_VARIABLE_BLOCK_COMMAND = createCommand('INSERT_WORKFLOW_VARIABLE_BLOCK_COMMAND') export const DELETE_WORKFLOW_VARIABLE_BLOCK_COMMAND = createCommand('DELETE_WORKFLOW_VARIABLE_BLOCK_COMMAND') -export const CLEAR_HIDE_MENU_TIMEOUT = createCommand('CLEAR_HIDE_MENU_TIMEOUT') export const UPDATE_WORKFLOW_NODES_MAP = createCommand('UPDATE_WORKFLOW_NODES_MAP') export type WorkflowVariableBlockProps = { @@ -49,7 +48,6 @@ const WorkflowVariableBlock = memo(({ editor.registerCommand( INSERT_WORKFLOW_VARIABLE_BLOCK_COMMAND, (variables: string[]) => { - editor.dispatchCommand(CLEAR_HIDE_MENU_TIMEOUT, undefined) const workflowVariableBlockNode = $createWorkflowVariableBlockNode(variables, workflowNodesMap, getVarType) $insertNodes([workflowVariableBlockNode]) diff --git a/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/__tests__/index.spec.tsx b/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/__tests__/index.spec.tsx index 7f292c8ff9..0a3470420c 100644 --- a/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/__tests__/index.spec.tsx +++ 
b/web/app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/__tests__/index.spec.tsx @@ -1612,9 +1612,7 @@ describe('Uploader', () => { if (!dropArea) return - fireEvent.drop(dropArea, { - dataTransfer: null, - }) + fireEvent.drop(dropArea) expect(updateFile).not.toHaveBeenCalled() }) diff --git a/web/app/components/datasets/create/website/watercrawl/__tests__/index.spec.tsx b/web/app/components/datasets/create/website/watercrawl/__tests__/index.spec.tsx index 5ff2d8efb8..e9d933cc03 100644 --- a/web/app/components/datasets/create/website/watercrawl/__tests__/index.spec.tsx +++ b/web/app/components/datasets/create/website/watercrawl/__tests__/index.spec.tsx @@ -1,6 +1,3 @@ -/** - * @vitest-environment jsdom - */ import type { Mock } from 'vitest' import type { CrawlOptions, CrawlResultItem } from '@/models/datasets' import { fireEvent, render, screen, waitFor } from '@testing-library/react' diff --git a/web/app/components/develop/__tests__/code.spec.tsx b/web/app/components/develop/__tests__/code.spec.tsx index 452e6ea98f..e5eaebb600 100644 --- a/web/app/components/develop/__tests__/code.spec.tsx +++ b/web/app/components/develop/__tests__/code.spec.tsx @@ -11,7 +11,7 @@ describe('code.tsx components', () => { vi.clearAllMocks() vi.spyOn(console, 'error').mockImplementation(() => {}) vi.useFakeTimers({ shouldAdvanceTime: true }) - // jsdom does not implement scrollBy; mock it to prevent stderr noise + // The test DOM runtime does not implement scrollBy; mock it to prevent stderr noise window.scrollBy = vi.fn() }) diff --git a/web/app/components/develop/__tests__/use-doc-toc.spec.ts b/web/app/components/develop/__tests__/use-doc-toc.spec.ts index e437e13065..b20c2c8ecf 100644 --- a/web/app/components/develop/__tests__/use-doc-toc.spec.ts +++ b/web/app/components/develop/__tests__/use-doc-toc.spec.ts @@ -307,7 +307,7 @@ describe('useDocToc', () => { it('should update activeSection when scrolling past a section', async () => { 
vi.useFakeTimers() - // innerHeight/2 = 384 in jsdom (default 768), so top <= 384 means "scrolled past" + // innerHeight/2 = 384 with the default test viewport height (768), so top <= 384 means "scrolled past" const { scrollContainer, cleanup } = setupScrollDOM([ { id: 'intro', text: 'Intro', top: 100 }, { id: 'details', text: 'Details', top: 600 }, diff --git a/web/app/components/header/account-setting/model-provider-page/model-auth/__tests__/add-custom-model.spec.tsx b/web/app/components/header/account-setting/model-provider-page/model-auth/__tests__/add-custom-model.spec.tsx index 6117420afa..43a27dac9b 100644 --- a/web/app/components/header/account-setting/model-provider-page/model-auth/__tests__/add-custom-model.spec.tsx +++ b/web/app/components/header/account-setting/model-provider-page/model-auth/__tests__/add-custom-model.spec.tsx @@ -43,7 +43,7 @@ vi.mock('@/app/components/base/tooltip', () => ({ ), })) -// Mock portal components to avoid async/jsdom issues (consistent with sibling tests) +// Mock portal components to avoid async test DOM issues (consistent with sibling tests) vi.mock('@/app/components/base/portal-to-follow-elem', () => ({ PortalToFollowElem: ({ children, open }: { children: React.ReactNode, open: boolean, onOpenChange: (open: boolean) => void }) => (
diff --git a/web/app/components/header/github-star/__tests__/index.spec.tsx b/web/app/components/header/github-star/__tests__/index.spec.tsx index 1790f31542..b800622137 100644 --- a/web/app/components/header/github-star/__tests__/index.spec.tsx +++ b/web/app/components/header/github-star/__tests__/index.spec.tsx @@ -1,11 +1,8 @@ import { QueryClient, QueryClientProvider } from '@tanstack/react-query' import { render, screen, waitFor } from '@testing-library/react' -import nock from 'nock' -import * as React from 'react' import GithubStar from '../index' -const GITHUB_HOST = 'https://api.github.com' -const GITHUB_PATH = '/repos/langgenius/dify' +const GITHUB_STAR_URL = 'https://ungh.cc/repos/langgenius/dify' const renderWithQueryClient = () => { const queryClient = new QueryClient({ @@ -18,40 +15,66 @@ const renderWithQueryClient = () => { ) } -const mockGithubStar = (status: number, body: Record, delayMs = 0) => { - return nock(GITHUB_HOST).get(GITHUB_PATH).delay(delayMs).reply(status, body) +const createJsonResponse = (body: Record, status = 200) => { + return new Response(JSON.stringify(body), { + status, + headers: { 'Content-Type': 'application/json' }, + }) +} + +const createDeferred = () => { + let resolve!: (value: T | PromiseLike) => void + let reject!: (reason?: unknown) => void + const promise = new Promise((res, rej) => { + resolve = res + reject = rej + }) + + return { promise, resolve, reject } } describe('GithubStar', () => { beforeEach(() => { - nock.cleanAll() + vi.restoreAllMocks() + vi.clearAllMocks() }) - // Shows fetched star count when request succeeds + // Covers the fetched star count shown after a successful request. 
it('should render fetched star count', async () => { - mockGithubStar(200, { stargazers_count: 123456 }) + const fetchSpy = vi.spyOn(globalThis, 'fetch').mockResolvedValue( + createJsonResponse({ repo: { stars: 123456 } }), + ) renderWithQueryClient() expect(await screen.findByText('123,456')).toBeInTheDocument() + expect(fetchSpy).toHaveBeenCalledWith(GITHUB_STAR_URL) }) - // Falls back to default star count when request fails + // Covers the fallback star count shown when the request fails. it('should render default star count on error', async () => { - mockGithubStar(500, {}) + vi.spyOn(globalThis, 'fetch').mockResolvedValue( + createJsonResponse({}, 500), + ) renderWithQueryClient() expect(await screen.findByText('110,918')).toBeInTheDocument() }) - // Renders loader while fetching data + // Covers the loading indicator while the fetch promise is still pending. it('should show loader while fetching', async () => { - mockGithubStar(200, { stargazers_count: 222222 }, 50) + const deferred = createDeferred() + vi.spyOn(globalThis, 'fetch').mockReturnValueOnce(deferred.promise) const { container } = renderWithQueryClient() expect(container.querySelector('.animate-spin')).toBeInTheDocument() - await waitFor(() => expect(screen.getByText('222,222')).toBeInTheDocument()) + + deferred.resolve(createJsonResponse({ repo: { stars: 222222 } })) + + await waitFor(() => { + expect(screen.getByText('222,222')).toBeInTheDocument() + }) }) }) diff --git a/web/app/components/header/github-star/index.tsx b/web/app/components/header/github-star/index.tsx index e91bdcca2c..44e8c5ac6f 100644 --- a/web/app/components/header/github-star/index.tsx +++ b/web/app/components/header/github-star/index.tsx @@ -1,16 +1,19 @@ 'use client' import type { FC } from 'react' -import type { GithubRepo } from '@/models/common' -import { RiLoader2Line } from '@remixicon/react' import { useQuery } from '@tanstack/react-query' -import { IS_DEV } from '@/config' -const defaultData = { - stargazers_count: 
110918, +type GithubStarResponse = { + repo: { + stars: number + } +} + +const defaultData: GithubStarResponse = { + repo: { stars: 110918 }, } const getStar = async () => { - const res = await fetch('https://api.github.com/repos/langgenius/dify') + const res = await fetch('https://ungh.cc/repos/langgenius/dify') if (!res.ok) throw new Error('Failed to fetch github star') @@ -19,21 +22,20 @@ const getStar = async () => { } const GithubStar: FC<{ className: string }> = (props) => { - const { isFetching, isError, data } = useQuery({ + const { isFetching, isError, data } = useQuery({ queryKey: ['github-star'], queryFn: getStar, - enabled: !IS_DEV, retry: false, placeholderData: defaultData, }) if (isFetching) - return + return if (isError) - return {defaultData.stargazers_count.toLocaleString()} + return {defaultData.repo.stars.toLocaleString()} - return {data?.stargazers_count.toLocaleString()} + return {data?.repo.stars.toLocaleString()} } export default GithubStar diff --git a/web/app/components/plugins/install-plugin/__tests__/hooks.spec.ts b/web/app/components/plugins/install-plugin/__tests__/hooks.spec.ts index 6b0fc27adf..b4171de7f0 100644 --- a/web/app/components/plugins/install-plugin/__tests__/hooks.spec.ts +++ b/web/app/components/plugins/install-plugin/__tests__/hooks.spec.ts @@ -1,6 +1,5 @@ -import { renderHook } from '@testing-library/react' import { beforeEach, describe, expect, it, vi } from 'vitest' -import { useGitHubReleases, useGitHubUpload } from '../hooks' +import { checkForUpdates, fetchReleases, handleUpload } from '../hooks' const mockNotify = vi.fn() vi.mock('@/app/components/base/ui/toast', () => ({ @@ -15,10 +14,6 @@ vi.mock('@/app/components/base/ui/toast', () => ({ }), })) -vi.mock('@/config', () => ({ - GITHUB_ACCESS_TOKEN: '', -})) - const mockUploadGitHub = vi.fn() vi.mock('@/service/plugins', () => ({ uploadGitHub: (...args: unknown[]) => mockUploadGitHub(...args), @@ -37,17 +32,17 @@ describe('install-plugin/hooks', () => { 
it('fetches releases from GitHub API and formats them', async () => { mockFetch.mockResolvedValue({ ok: true, - json: () => Promise.resolve([ - { - tag_name: 'v1.0.0', - assets: [{ browser_download_url: 'https://example.com/v1.zip', name: 'plugin.zip' }], - body: 'Release notes', - }, - ]), + json: () => Promise.resolve({ + releases: [ + { + tag: 'v1.0.0', + assets: [{ downloadUrl: 'https://example.com/plugin.zip' }], + }, + ], + }), }) - const { result } = renderHook(() => useGitHubReleases()) - const releases = await result.current.fetchReleases('owner', 'repo') + const releases = await fetchReleases('owner', 'repo') expect(releases).toHaveLength(1) expect(releases[0].tag_name).toBe('v1.0.0') @@ -60,8 +55,7 @@ describe('install-plugin/hooks', () => { ok: false, }) - const { result } = renderHook(() => useGitHubReleases()) - const releases = await result.current.fetchReleases('owner', 'repo') + const releases = await fetchReleases('owner', 'repo') expect(releases).toEqual([]) expect(mockNotify).toHaveBeenCalledWith('Failed to fetch repository releases') @@ -70,29 +64,26 @@ describe('install-plugin/hooks', () => { describe('checkForUpdates', () => { it('detects newer version available', () => { - const { result } = renderHook(() => useGitHubReleases()) const releases = [ { tag_name: 'v1.0.0', assets: [] }, { tag_name: 'v2.0.0', assets: [] }, ] - const { needUpdate, toastProps } = result.current.checkForUpdates(releases, 'v1.0.0') + const { needUpdate, toastProps } = checkForUpdates(releases, 'v1.0.0') expect(needUpdate).toBe(true) expect(toastProps.message).toContain('v2.0.0') }) it('returns no update when current is latest', () => { - const { result } = renderHook(() => useGitHubReleases()) const releases = [ { tag_name: 'v1.0.0', assets: [] }, ] - const { needUpdate, toastProps } = result.current.checkForUpdates(releases, 'v1.0.0') + const { needUpdate, toastProps } = checkForUpdates(releases, 'v1.0.0') expect(needUpdate).toBe(false) 
expect(toastProps.type).toBe('info') }) it('returns error for empty releases', () => { - const { result } = renderHook(() => useGitHubReleases()) - const { needUpdate, toastProps } = result.current.checkForUpdates([], 'v1.0.0') + const { needUpdate, toastProps } = checkForUpdates([], 'v1.0.0') expect(needUpdate).toBe(false) expect(toastProps.type).toBe('error') expect(toastProps.message).toContain('empty') @@ -109,8 +100,7 @@ describe('install-plugin/hooks', () => { }) const onSuccess = vi.fn() - const { result } = renderHook(() => useGitHubUpload()) - const pkg = await result.current.handleUpload( + const pkg = await handleUpload( 'https://github.com/owner/repo', 'v1.0.0', 'plugin.difypkg', @@ -132,9 +122,8 @@ describe('install-plugin/hooks', () => { it('shows toast on upload error', async () => { mockUploadGitHub.mockRejectedValue(new Error('Upload failed')) - const { result } = renderHook(() => useGitHubUpload()) await expect( - result.current.handleUpload('url', 'v1', 'pkg'), + handleUpload('url', 'v1', 'pkg'), ).rejects.toThrow('Upload failed') expect(mockNotify).toHaveBeenCalledWith('Error uploading package') }) diff --git a/web/app/components/plugins/install-plugin/hooks.ts b/web/app/components/plugins/install-plugin/hooks.ts index cc7148cc17..f86e6ad672 100644 --- a/web/app/components/plugins/install-plugin/hooks.ts +++ b/web/app/components/plugins/install-plugin/hooks.ts @@ -1,101 +1,87 @@ import type { GitHubRepoReleaseResponse } from '../types' import { toast } from '@/app/components/base/ui/toast' -import { GITHUB_ACCESS_TOKEN } from '@/config' import { uploadGitHub } from '@/service/plugins' import { compareVersion, getLatestVersion } from '@/utils/semver' +const normalizeAssetName = (downloadUrl: string) => { + const parts = downloadUrl.split('/') + return parts[parts.length - 1] +} + const formatReleases = (releases: any) => { return releases.map((release: any) => ({ - tag_name: release.tag_name, + tag_name: release.tag, assets: 
release.assets.map((asset: any) => ({ - browser_download_url: asset.browser_download_url, - name: asset.name, + browser_download_url: asset.downloadUrl, + name: normalizeAssetName(asset.downloadUrl), })), })) } -export const useGitHubReleases = () => { - const fetchReleases = async (owner: string, repo: string) => { - try { - if (!GITHUB_ACCESS_TOKEN) { - // Fetch releases without authentication from client - const res = await fetch(`https://api.github.com/repos/${owner}/${repo}/releases`) - if (!res.ok) - throw new Error('Failed to fetch repository releases') - const data = await res.json() - return formatReleases(data) - } - else { - // Fetch releases with authentication from server - const res = await fetch(`/repos/${owner}/${repo}/releases`) - const bodyJson = await res.json() - if (bodyJson.status !== 200) - throw new Error(bodyJson.data.message) - return formatReleases(bodyJson.data) - } - } - catch (error) { - if (error instanceof Error) { - toast.error(error.message) - } - else { - toast.error('Failed to fetch repository releases') - } - return [] - } +export const fetchReleases = async (owner: string, repo: string) => { + try { + // Fetch releases without authentication from client + const res = await fetch(`https://ungh.cc/repos/${owner}/${repo}/releases`) + if (!res.ok) + throw new Error('Failed to fetch repository releases') + const data = await res.json() + return formatReleases(data.releases) } + catch (error) { + if (error instanceof Error) { + toast.error(error.message) + } + else { + toast.error('Failed to fetch repository releases') + } + return [] + } +} - const checkForUpdates = (fetchedReleases: GitHubRepoReleaseResponse[], currentVersion: string) => { - let needUpdate = false - const toastProps: { type?: 'success' | 'error' | 'info' | 'warning', message: string } = { - type: 'info', - message: 'No new version available', - } - if (fetchedReleases.length === 0) { - toastProps.type = 'error' - toastProps.message = 'Input releases is empty' - 
return { needUpdate, toastProps } - } - const versions = fetchedReleases.map(release => release.tag_name) - const latestVersion = getLatestVersion(versions) - try { - needUpdate = compareVersion(latestVersion, currentVersion) === 1 - if (needUpdate) - toastProps.message = `New version available: ${latestVersion}` - } - catch { - needUpdate = false - toastProps.type = 'error' - toastProps.message = 'Fail to compare versions, please check the version format' - } +export const checkForUpdates = (fetchedReleases: GitHubRepoReleaseResponse[], currentVersion: string) => { + let needUpdate = false + const toastProps: { type?: 'success' | 'error' | 'info' | 'warning', message: string } = { + type: 'info', + message: 'No new version available', + } + if (fetchedReleases.length === 0) { + toastProps.type = 'error' + toastProps.message = 'Input releases is empty' return { needUpdate, toastProps } } - - return { fetchReleases, checkForUpdates } -} - -export const useGitHubUpload = () => { - const handleUpload = async ( - repoUrl: string, - selectedVersion: string, - selectedPackage: string, - onSuccess?: (GitHubPackage: { manifest: any, unique_identifier: string }) => void, - ) => { - try { - const response = await uploadGitHub(repoUrl, selectedVersion, selectedPackage) - const GitHubPackage = { - manifest: response.manifest, - unique_identifier: response.unique_identifier, - } - if (onSuccess) - onSuccess(GitHubPackage) - return GitHubPackage - } - catch (error) { - toast.error('Error uploading package') - throw error - } + const versions = fetchedReleases.map(release => release.tag_name) + const latestVersion = getLatestVersion(versions) + try { + needUpdate = compareVersion(latestVersion, currentVersion) === 1 + if (needUpdate) + toastProps.message = `New version available: ${latestVersion}` + } + catch { + needUpdate = false + toastProps.type = 'error' + toastProps.message = 'Fail to compare versions, please check the version format' + } + return { needUpdate, toastProps } 
+} + +export const handleUpload = async ( + repoUrl: string, + selectedVersion: string, + selectedPackage: string, + onSuccess?: (GitHubPackage: { manifest: any, unique_identifier: string }) => void, +) => { + try { + const response = await uploadGitHub(repoUrl, selectedVersion, selectedPackage) + const GitHubPackage = { + manifest: response.manifest, + unique_identifier: response.unique_identifier, + } + if (onSuccess) + onSuccess(GitHubPackage) + return GitHubPackage + } + catch (error) { + toast.error('Error uploading package') + throw error } - - return { handleUpload } } diff --git a/web/app/components/plugins/install-plugin/install-from-github/__tests__/index.spec.tsx b/web/app/components/plugins/install-plugin/install-from-github/__tests__/index.spec.tsx index 8abec7817b..2f35f484f0 100644 --- a/web/app/components/plugins/install-plugin/install-from-github/__tests__/index.spec.tsx +++ b/web/app/components/plugins/install-plugin/install-from-github/__tests__/index.spec.tsx @@ -74,10 +74,16 @@ vi.mock('@/app/components/plugins/install-plugin/base/use-get-icon', () => ({ default: () => ({ getIconUrl: mockGetIconUrl }), })) -const mockFetchReleases = vi.fn() -vi.mock('../../hooks', () => ({ - useGitHubReleases: () => ({ fetchReleases: mockFetchReleases }), +const { mockFetchReleases } = vi.hoisted(() => ({ + mockFetchReleases: vi.fn(), })) +vi.mock('../../hooks', async (importOriginal) => { + const actual = await importOriginal() + return { + ...actual, + fetchReleases: mockFetchReleases, + } +}) const mockRefreshPluginList = vi.fn() vi.mock('../../hooks/use-refresh-plugin-list', () => ({ diff --git a/web/app/components/plugins/install-plugin/install-from-github/index.tsx b/web/app/components/plugins/install-plugin/install-from-github/index.tsx index ff51698478..c7ace3d94f 100644 --- a/web/app/components/plugins/install-plugin/install-from-github/index.tsx +++ b/web/app/components/plugins/install-plugin/install-from-github/index.tsx @@ -12,7 +12,7 @@ import 
useGetIcon from '@/app/components/plugins/install-plugin/base/use-get-ico import { cn } from '@/utils/classnames' import { InstallStepFromGitHub } from '../../types' import Installed from '../base/installed' -import { useGitHubReleases } from '../hooks' +import { fetchReleases } from '../hooks' import useHideLogic from '../hooks/use-hide-logic' import useRefreshPluginList from '../hooks/use-refresh-plugin-list' import { convertRepoToUrl, parseGitHubUrl } from '../utils' @@ -31,7 +31,6 @@ type InstallFromGitHubProps = { const InstallFromGitHub: React.FC = ({ updatePayload, onClose, onSuccess }) => { const { t } = useTranslation() const { getIconUrl } = useGetIcon() - const { fetchReleases } = useGitHubReleases() const { refreshPluginList } = useRefreshPluginList() const { diff --git a/web/app/components/plugins/install-plugin/install-from-github/steps/__tests__/selectPackage.spec.tsx b/web/app/components/plugins/install-plugin/install-from-github/steps/__tests__/selectPackage.spec.tsx index 060a5c92a1..cd40f93339 100644 --- a/web/app/components/plugins/install-plugin/install-from-github/steps/__tests__/selectPackage.spec.tsx +++ b/web/app/components/plugins/install-plugin/install-from-github/steps/__tests__/selectPackage.spec.tsx @@ -5,11 +5,17 @@ import { beforeEach, describe, expect, it, vi } from 'vitest' import { PluginCategoryEnum } from '../../../../types' import SelectPackage from '../selectPackage' -// Mock the useGitHubUpload hook -const mockHandleUpload = vi.fn() -vi.mock('../../../hooks', () => ({ - useGitHubUpload: () => ({ handleUpload: mockHandleUpload }), +// Mock upload helper from hooks module +const { mockHandleUpload } = vi.hoisted(() => ({ + mockHandleUpload: vi.fn(), })) +vi.mock('../../../hooks', async (importOriginal) => { + const actual = await importOriginal() + return { + ...actual, + handleUpload: mockHandleUpload, + } +}) // Factory functions const createMockManifest = (): PluginDeclaration => ({ diff --git 
a/web/app/components/plugins/install-plugin/install-from-github/steps/selectPackage.tsx b/web/app/components/plugins/install-plugin/install-from-github/steps/selectPackage.tsx index 40e32e6c3b..94339eaa97 100644 --- a/web/app/components/plugins/install-plugin/install-from-github/steps/selectPackage.tsx +++ b/web/app/components/plugins/install-plugin/install-from-github/steps/selectPackage.tsx @@ -6,7 +6,7 @@ import * as React from 'react' import { useTranslation } from 'react-i18next' import Button from '@/app/components/base/button' import { PortalSelect } from '@/app/components/base/select' -import { useGitHubUpload } from '../../hooks' +import { handleUpload } from '../../hooks' const i18nPrefix = 'installFromGitHub' @@ -43,7 +43,6 @@ const SelectPackage: React.FC = ({ const { t } = useTranslation() const isEdit = Boolean(updatePayload) const [isUploading, setIsUploading] = React.useState(false) - const { handleUpload } = useGitHubUpload() const handleUploadPackage = async () => { if (isUploading) diff --git a/web/app/components/plugins/plugin-auth/authorize/api-key-modal.tsx b/web/app/components/plugins/plugin-auth/authorize/api-key-modal.tsx index d8ab1cafdd..38ce5cca4e 100644 --- a/web/app/components/plugins/plugin-auth/authorize/api-key-modal.tsx +++ b/web/app/components/plugins/plugin-auth/authorize/api-key-modal.tsx @@ -142,7 +142,7 @@ const ApiKeyModal = ({ onExtraButtonClick={onRemove} disabled={disabled || isLoading || doingAction} clickOutsideNotClose={true} - wrapperClassName="!z-[101]" + wrapperClassName="!z-[1002]" > {pluginPayload.detail && ( diff --git a/web/app/components/plugins/plugin-auth/authorize/oauth-client-settings.tsx b/web/app/components/plugins/plugin-auth/authorize/oauth-client-settings.tsx index 28989da77c..a69cc1538b 100644 --- a/web/app/components/plugins/plugin-auth/authorize/oauth-client-settings.tsx +++ b/web/app/components/plugins/plugin-auth/authorize/oauth-client-settings.tsx @@ -157,7 +157,7 @@ const OAuthClientSettings = ({ 
) } containerClassName="pt-0" - wrapperClassName="!z-[101]" + wrapperClassName="!z-[1002]" clickOutsideNotClose={true} > {pluginPayload.detail && ( diff --git a/web/app/components/plugins/plugin-detail-panel/__tests__/detail-header.spec.tsx b/web/app/components/plugins/plugin-detail-panel/__tests__/detail-header.spec.tsx index f8d6488128..e0aa13948b 100644 --- a/web/app/components/plugins/plugin-detail-panel/__tests__/detail-header.spec.tsx +++ b/web/app/components/plugins/plugin-detail-panel/__tests__/detail-header.spec.tsx @@ -103,12 +103,14 @@ vi.mock('@/service/use-tools', () => ({ useInvalidateAllToolProviders: () => mockInvalidateAllToolProviders, })) -vi.mock('../../install-plugin/hooks', () => ({ - useGitHubReleases: () => ({ +vi.mock('../../install-plugin/hooks', async (importOriginal) => { + const actual = await importOriginal() + return { + ...actual, checkForUpdates: mockCheckForUpdates, fetchReleases: mockFetchReleases, - }), -})) + } +}) // Auto upgrade settings mock let mockAutoUpgradeInfo: { diff --git a/web/app/components/plugins/plugin-detail-panel/__tests__/endpoint-card.spec.tsx b/web/app/components/plugins/plugin-detail-panel/__tests__/endpoint-card.spec.tsx index 237c72adf0..2af14c5864 100644 --- a/web/app/components/plugins/plugin-detail-panel/__tests__/endpoint-card.spec.tsx +++ b/web/app/components/plugins/plugin-detail-panel/__tests__/endpoint-card.spec.tsx @@ -142,7 +142,7 @@ describe('EndpointCard', () => { failureFlags.disable = false failureFlags.delete = false failureFlags.update = false - // Polyfill document.execCommand for copy-to-clipboard in jsdom + // Polyfill document.execCommand for copy-to-clipboard in the test DOM runtime if (typeof document.execCommand !== 'function') { document.execCommand = vi.fn().mockReturnValue(true) } diff --git a/web/app/components/plugins/plugin-detail-panel/detail-header/hooks/__tests__/use-plugin-operations.spec.ts 
b/web/app/components/plugins/plugin-detail-panel/detail-header/hooks/__tests__/use-plugin-operations.spec.ts index 77d41c5bce..b8873f1087 100644 --- a/web/app/components/plugins/plugin-detail-panel/detail-header/hooks/__tests__/use-plugin-operations.spec.ts +++ b/web/app/components/plugins/plugin-detail-panel/detail-header/hooks/__tests__/use-plugin-operations.spec.ts @@ -72,12 +72,14 @@ vi.mock('@/service/use-tools', () => ({ useInvalidateAllToolProviders: () => mockInvalidateAllToolProviders, })) -vi.mock('../../../../install-plugin/hooks', () => ({ - useGitHubReleases: () => ({ +vi.mock('../../../../install-plugin/hooks', async (importOriginal) => { + const actual = await importOriginal() + return { + ...actual, checkForUpdates: mockCheckForUpdates, fetchReleases: mockFetchReleases, - }), -})) + } +}) const createPluginDetail = (overrides: Partial = {}): PluginDetail => ({ id: 'test-id', diff --git a/web/app/components/plugins/plugin-detail-panel/detail-header/hooks/use-plugin-operations.ts b/web/app/components/plugins/plugin-detail-panel/detail-header/hooks/use-plugin-operations.ts index ade47cec5f..765c0e8a4e 100644 --- a/web/app/components/plugins/plugin-detail-panel/detail-header/hooks/use-plugin-operations.ts +++ b/web/app/components/plugins/plugin-detail-panel/detail-header/hooks/use-plugin-operations.ts @@ -11,7 +11,7 @@ import { useProviderContext } from '@/context/provider-context' import { uninstallPlugin } from '@/service/plugins' import { useInvalidateCheckInstalled } from '@/service/use-plugins' import { useInvalidateAllToolProviders } from '@/service/use-tools' -import { useGitHubReleases } from '../../../install-plugin/hooks' +import { checkForUpdates, fetchReleases } from '../../../install-plugin/hooks' import { PluginCategoryEnum, PluginSource } from '../../../types' type UsePluginOperationsParams = { @@ -39,7 +39,6 @@ export const usePluginOperations = ({ onUpdate, }: UsePluginOperationsParams): UsePluginOperationsReturn => { const { t } = 
useTranslation() - const { checkForUpdates, fetchReleases } = useGitHubReleases() const { setShowUpdatePluginModal } = useModalContext() const { refreshModelProviders } = useProviderContext() const invalidateCheckInstalled = useInvalidateCheckInstalled() diff --git a/web/app/components/plugins/plugin-detail-panel/subscription-list/create/__tests__/oauth-client.spec.tsx b/web/app/components/plugins/plugin-detail-panel/subscription-list/create/__tests__/oauth-client.spec.tsx index ce53bf5b9a..5c4407b3c5 100644 --- a/web/app/components/plugins/plugin-detail-panel/subscription-list/create/__tests__/oauth-client.spec.tsx +++ b/web/app/components/plugins/plugin-detail-panel/subscription-list/create/__tests__/oauth-client.spec.tsx @@ -102,10 +102,12 @@ vi.mock('@/app/components/base/ui/toast', () => ({ })) const mockClipboardWriteText = vi.fn() -Object.assign(navigator, { - clipboard: { +Object.defineProperty(navigator, 'clipboard', { + value: { writeText: mockClipboardWriteText, }, + configurable: true, + writable: true, }) vi.mock('@/app/components/base/modal/modal', () => ({ @@ -192,6 +194,13 @@ describe('OAuthClientSettingsModal', () => { vi.clearAllMocks() mockUsePluginStore.mockReturnValue(mockPluginDetail) mockClipboardWriteText.mockResolvedValue(undefined) + Object.defineProperty(navigator, 'clipboard', { + value: { + writeText: mockClipboardWriteText, + }, + configurable: true, + writable: true, + }) setMockFormValues({ values: { client_id: 'test-client-id', client_secret: 'test-client-secret' }, isCheckValidated: true, diff --git a/web/app/components/plugins/plugin-item/__tests__/action.spec.tsx b/web/app/components/plugins/plugin-item/__tests__/action.spec.tsx index b4d21c9403..45f5deba22 100644 --- a/web/app/components/plugins/plugin-item/__tests__/action.spec.tsx +++ b/web/app/components/plugins/plugin-item/__tests__/action.spec.tsx @@ -46,13 +46,15 @@ vi.mock('@/service/plugins', () => ({ uninstallPlugin: (id: string) => mockUninstallPlugin(id), })) -// Mock 
GitHub releases hook -vi.mock('../../install-plugin/hooks', () => ({ - useGitHubReleases: () => ({ +// Mock GitHub release helpers +vi.mock('../../install-plugin/hooks', async (importOriginal) => { + const actual = await importOriginal() + return { + ...actual, fetchReleases: mockFetchReleases, checkForUpdates: mockCheckForUpdates, - }), -})) + } +}) // Mock modal context vi.mock('@/context/modal-context', () => ({ diff --git a/web/app/components/plugins/plugin-item/action.tsx b/web/app/components/plugins/plugin-item/action.tsx index 413b41e895..c01be54442 100644 --- a/web/app/components/plugins/plugin-item/action.tsx +++ b/web/app/components/plugins/plugin-item/action.tsx @@ -14,7 +14,7 @@ import { useInvalidateInstalledPluginList } from '@/service/use-plugins' import ActionButton from '../../base/action-button' import Confirm from '../../base/confirm' import Tooltip from '../../base/tooltip' -import { useGitHubReleases } from '../install-plugin/hooks' +import { checkForUpdates, fetchReleases } from '../install-plugin/hooks' import PluginInfo from '../plugin-page/plugin-info' import { PluginSource } from '../types' @@ -54,7 +54,6 @@ const Action: FC = ({ setTrue: showDeleting, setFalse: hideDeleting, }] = useBoolean(false) - const { checkForUpdates, fetchReleases } = useGitHubReleases() const { setShowUpdatePluginModal } = useModalContext() const invalidateInstalledPluginList = useInvalidateInstalledPluginList() diff --git a/web/app/components/rag-pipeline/components/chunk-card-list/types.ts b/web/app/components/rag-pipeline/components/chunk-card-list/types.ts index 6117855b3b..b1213917e4 100644 --- a/web/app/components/rag-pipeline/components/chunk-card-list/types.ts +++ b/web/app/components/rag-pipeline/components/chunk-card-list/types.ts @@ -26,7 +26,8 @@ export type QAChunks = { export type ChunkInfo = GeneralChunks | ParentChildChunks | QAChunks -export enum QAItemType { - Question = 'question', - Answer = 'answer', -} +export const QAItemType = { + Question: 
'question', + Answer: 'answer', +} as const +export type QAItemType = typeof QAItemType[keyof typeof QAItemType] diff --git a/web/app/components/rag-pipeline/components/panel/test-run/types.ts b/web/app/components/rag-pipeline/components/panel/test-run/types.ts index ac5ca3ced2..d129487fe1 100644 --- a/web/app/components/rag-pipeline/components/panel/test-run/types.ts +++ b/web/app/components/rag-pipeline/components/panel/test-run/types.ts @@ -1,9 +1,10 @@ import type { DataSourceNodeType } from '@/app/components/workflow/nodes/data-source/types' -export enum TestRunStep { - dataSource = 'dataSource', - documentProcessing = 'documentProcessing', -} +export const TestRunStep = { + dataSource: 'dataSource', + documentProcessing: 'documentProcessing', +} as const +export type TestRunStep = typeof TestRunStep[keyof typeof TestRunStep] export type DataSourceOption = { label: string diff --git a/web/app/components/workflow/__tests__/custom-edge-linear-gradient-render.spec.tsx b/web/app/components/workflow/__tests__/custom-edge-linear-gradient-render.spec.tsx index e962923158..973dfacbc8 100644 --- a/web/app/components/workflow/__tests__/custom-edge-linear-gradient-render.spec.tsx +++ b/web/app/components/workflow/__tests__/custom-edge-linear-gradient-render.spec.tsx @@ -48,10 +48,10 @@ describe('CustomEdgeLinearGradientRender', () => { const stops = container.querySelectorAll('stop') expect(stops).toHaveLength(2) expect(stops[0]).toHaveAttribute('offset', '0%') - expect(stops[0].getAttribute('style')).toContain('stop-color: rgb(17, 17, 17)') + expect(stops[0].getAttribute('style')).toContain('stop-color: #111111') expect(stops[0].getAttribute('style')).toContain('stop-opacity: 1') expect(stops[1]).toHaveAttribute('offset', '100%') - expect(stops[1].getAttribute('style')).toContain('stop-color: rgb(34, 34, 34)') + expect(stops[1].getAttribute('style')).toContain('stop-color: #222222') expect(stops[1].getAttribute('style')).toContain('stop-opacity: 1') }) }) diff --git 
a/web/app/components/workflow/__tests__/selection-contextmenu.spec.tsx b/web/app/components/workflow/__tests__/selection-contextmenu.spec.tsx index 247184349d..1106cfcb75 100644 --- a/web/app/components/workflow/__tests__/selection-contextmenu.spec.tsx +++ b/web/app/components/workflow/__tests__/selection-contextmenu.spec.tsx @@ -10,6 +10,9 @@ import { renderWorkflowFlowComponent } from './workflow-test-env' let latestNodes: Node[] = [] let latestHistoryEvent: string | undefined const mockGetNodesReadOnly = vi.fn() +const mockHandleNodesCopy = vi.fn() +const mockHandleNodesDuplicate = vi.fn() +const mockHandleNodesDelete = vi.fn() vi.mock('../hooks', async () => { const actual = await vi.importActual('../hooks') @@ -18,6 +21,11 @@ vi.mock('../hooks', async () => { useNodesReadOnly: () => ({ getNodesReadOnly: mockGetNodesReadOnly, }), + useNodesInteractions: () => ({ + handleNodesCopy: mockHandleNodesCopy, + handleNodesDuplicate: mockHandleNodesDuplicate, + handleNodesDelete: mockHandleNodesDelete, + }), } }) @@ -73,6 +81,9 @@ describe('SelectionContextmenu', () => { latestHistoryEvent = undefined mockGetNodesReadOnly.mockReset() mockGetNodesReadOnly.mockReturnValue(false) + mockHandleNodesCopy.mockReset() + mockHandleNodesDuplicate.mockReset() + mockHandleNodesDelete.mockReset() }) it('should not render when selectionMenu is absent', () => { @@ -81,7 +92,7 @@ describe('SelectionContextmenu', () => { expect(screen.queryByText('operator.vertical')).not.toBeInTheDocument() }) - it('should keep the menu inside the workflow container bounds', () => { + it('should render menu items when selectionMenu is present', async () => { const nodes = [ createNode({ id: 'n1', selected: true, width: 80, height: 40 }), createNode({ id: 'n2', selected: true, position: { x: 140, y: 0 }, width: 80, height: 40 }), @@ -89,11 +100,46 @@ describe('SelectionContextmenu', () => { const { store } = renderSelectionMenu({ nodes }) act(() => { - store.setState({ selectionMenu: { left: 780, top: 
590 } }) + store.setState({ selectionMenu: { clientX: 780, clientY: 590 } }) }) - const menu = screen.getByTestId('selection-contextmenu') - expect(menu).toHaveStyle({ left: '540px', top: '210px' }) + await waitFor(() => { + expect(screen.getByTestId('selection-contextmenu-item-left')).toBeInTheDocument() + }) + }) + + it('should render and execute copy/duplicate/delete operations', async () => { + const nodes = [ + createNode({ id: 'n1', selected: true, width: 80, height: 40 }), + createNode({ id: 'n2', selected: true, position: { x: 140, y: 0 }, width: 80, height: 40 }), + ] + const { store } = renderSelectionMenu({ nodes }) + + act(() => { + store.setState({ selectionMenu: { clientX: 120, clientY: 120 } }) + }) + + await waitFor(() => { + expect(screen.getByTestId('selection-contextmenu-item-copy')).toBeInTheDocument() + }) + + fireEvent.click(screen.getByTestId('selection-contextmenu-item-copy')) + expect(mockHandleNodesCopy).toHaveBeenCalledTimes(1) + expect(store.getState().selectionMenu).toBeUndefined() + + act(() => { + store.setState({ selectionMenu: { clientX: 120, clientY: 120 } }) + }) + fireEvent.click(screen.getByTestId('selection-contextmenu-item-duplicate')) + expect(mockHandleNodesDuplicate).toHaveBeenCalledTimes(1) + expect(store.getState().selectionMenu).toBeUndefined() + + act(() => { + store.setState({ selectionMenu: { clientX: 120, clientY: 120 } }) + }) + fireEvent.click(screen.getByTestId('selection-contextmenu-item-delete')) + expect(mockHandleNodesDelete).toHaveBeenCalledTimes(1) + expect(store.getState().selectionMenu).toBeUndefined() }) it('should close itself when only one node is selected', async () => { @@ -104,7 +150,7 @@ describe('SelectionContextmenu', () => { const { store } = renderSelectionMenu({ nodes }) act(() => { - store.setState({ selectionMenu: { left: 120, top: 120 } }) + store.setState({ selectionMenu: { clientX: 120, clientY: 120 } }) }) await waitFor(() => { @@ -129,7 +175,7 @@ describe('SelectionContextmenu', () => { 
}) act(() => { - store.setState({ selectionMenu: { left: 100, top: 100 } }) + store.setState({ selectionMenu: { clientX: 100, clientY: 100 } }) }) fireEvent.click(screen.getByTestId('selection-contextmenu-item-left')) @@ -162,7 +208,7 @@ describe('SelectionContextmenu', () => { }) act(() => { - store.setState({ selectionMenu: { left: 160, top: 120 } }) + store.setState({ selectionMenu: { clientX: 160, clientY: 120 } }) }) fireEvent.click(screen.getByTestId('selection-contextmenu-item-distributeHorizontal')) @@ -201,7 +247,7 @@ describe('SelectionContextmenu', () => { }) act(() => { - store.setState({ selectionMenu: { left: 180, top: 120 } }) + store.setState({ selectionMenu: { clientX: 180, clientY: 120 } }) }) fireEvent.click(screen.getByTestId('selection-contextmenu-item-left')) @@ -220,7 +266,7 @@ describe('SelectionContextmenu', () => { const { store } = renderSelectionMenu({ nodes }) act(() => { - store.setState({ selectionMenu: { left: 100, top: 100 } }) + store.setState({ selectionMenu: { clientX: 100, clientY: 100 } }) }) fireEvent.click(screen.getByTestId('selection-contextmenu-item-left')) @@ -238,7 +284,7 @@ describe('SelectionContextmenu', () => { const { store } = renderSelectionMenu({ nodes }) act(() => { - store.setState({ selectionMenu: { left: 100, top: 100 } }) + store.setState({ selectionMenu: { clientX: 100, clientY: 100 } }) }) fireEvent.click(screen.getByTestId('selection-contextmenu-item-left')) @@ -263,7 +309,7 @@ describe('SelectionContextmenu', () => { const { store } = renderSelectionMenu({ nodes }) act(() => { - store.setState({ selectionMenu: { left: 100, top: 100 } }) + store.setState({ selectionMenu: { clientX: 100, clientY: 100 } }) }) fireEvent.click(screen.getByTestId('selection-contextmenu-item-left')) diff --git a/web/app/components/workflow/__tests__/update-dsl-modal.spec.tsx b/web/app/components/workflow/__tests__/update-dsl-modal.spec.tsx index 82645f2028..961ab6ddb4 100644 --- 
a/web/app/components/workflow/__tests__/update-dsl-modal.spec.tsx +++ b/web/app/components/workflow/__tests__/update-dsl-modal.spec.tsx @@ -209,7 +209,7 @@ describe('UpdateDSLModal', () => { }) await waitFor(() => { - expect(screen.getByRole('button', { name: 'app.newApp.Cancel' })).toBeInTheDocument() + expect(screen.getByRole('button', { name: 'app.newApp.Confirm' })).toBeInTheDocument() }, { timeout: 1000 }) fireEvent.click(screen.getByRole('button', { name: 'app.newApp.Cancel' })) diff --git a/web/app/components/workflow/hooks/__tests__/use-panel-interactions.spec.ts b/web/app/components/workflow/hooks/__tests__/use-panel-interactions.spec.ts index 517af513b9..83c21fcb6a 100644 --- a/web/app/components/workflow/hooks/__tests__/use-panel-interactions.spec.ts +++ b/web/app/components/workflow/hooks/__tests__/use-panel-interactions.spec.ts @@ -29,7 +29,7 @@ describe('usePanelInteractions', () => { const { result, store } = renderWorkflowHook(() => usePanelInteractions(), { initialStoreState: { nodeMenu: { top: 20, left: 40, nodeId: 'n1' }, - selectionMenu: { top: 30, left: 50 }, + selectionMenu: { clientX: 30, clientY: 50 }, edgeMenu: { clientX: 320, clientY: 180, edgeId: 'e1' }, }, }) diff --git a/web/app/components/workflow/hooks/__tests__/use-selection-interactions.spec.ts b/web/app/components/workflow/hooks/__tests__/use-selection-interactions.spec.ts index 31d5d82475..5f584f33d7 100644 --- a/web/app/components/workflow/hooks/__tests__/use-selection-interactions.spec.ts +++ b/web/app/components/workflow/hooks/__tests__/use-selection-interactions.spec.ts @@ -200,8 +200,8 @@ describe('useSelectionInteractions', () => { }) expect(store.getState().selectionMenu).toEqual({ - top: 150, - left: 200, + clientX: 300, + clientY: 200, }) expect(store.getState().nodeMenu).toBeUndefined() expect(store.getState().panelMenu).toBeUndefined() @@ -210,7 +210,7 @@ describe('useSelectionInteractions', () => { it('handleSelectionContextmenuCancel should clear selectionMenu', () => 
{ const { result, store } = renderSelectionInteractions({ - selectionMenu: { top: 50, left: 60 }, + selectionMenu: { clientX: 50, clientY: 60 }, }) act(() => { diff --git a/web/app/components/workflow/hooks/use-selection-interactions.ts b/web/app/components/workflow/hooks/use-selection-interactions.ts index 3c05d64cc4..793897a1af 100644 --- a/web/app/components/workflow/hooks/use-selection-interactions.ts +++ b/web/app/components/workflow/hooks/use-selection-interactions.ts @@ -137,15 +137,13 @@ export const useSelectionInteractions = () => { return e.preventDefault() - const container = document.querySelector('#workflow-container') - const { x, y } = container!.getBoundingClientRect() workflowStore.setState({ nodeMenu: undefined, panelMenu: undefined, edgeMenu: undefined, selectionMenu: { - top: e.clientY - y, - left: e.clientX - x, + clientX: e.clientX, + clientY: e.clientY, }, }) }, [workflowStore]) diff --git a/web/app/components/workflow/nodes/trigger-schedule/components/__tests__/integration.spec.tsx b/web/app/components/workflow/nodes/trigger-schedule/components/__tests__/integration.spec.tsx index 00a6cbbe29..f2441e78d0 100644 --- a/web/app/components/workflow/nodes/trigger-schedule/components/__tests__/integration.spec.tsx +++ b/web/app/components/workflow/nodes/trigger-schedule/components/__tests__/integration.spec.tsx @@ -1,6 +1,6 @@ /* eslint-disable ts/no-explicit-any */ import type { ScheduleTriggerNodeType } from '../../types' -import { fireEvent, render, screen, waitFor, within } from '@testing-library/react' +import { render, screen, waitFor, within } from '@testing-library/react' import userEvent from '@testing-library/user-event' import FrequencySelector from '../frequency-selector' import ModeSwitcher from '../mode-switcher' @@ -44,14 +44,14 @@ describe('trigger-schedule components', () => { ) const trigger = screen.getByRole('button', { name: 'workflow.nodes.triggerSchedule.frequency.daily' }) - fireEvent.click(trigger) + await 
user.click(trigger) await waitFor(() => { expect(trigger).toHaveAttribute('aria-expanded', 'true') }) const listbox = await screen.findByRole('listbox') - await user.click(within(listbox).getByText('workflow.nodes.triggerSchedule.frequency.weekly')) + await user.click(within(listbox).getByRole('option', { name: 'workflow.nodes.triggerSchedule.frequency.weekly' })) await waitFor(() => { expect(onChange).toHaveBeenCalledWith('weekly') diff --git a/web/app/components/workflow/note-node/types.ts b/web/app/components/workflow/note-node/types.ts index ad68bd0f10..e8a18589fd 100644 --- a/web/app/components/workflow/note-node/types.ts +++ b/web/app/components/workflow/note-node/types.ts @@ -1,13 +1,14 @@ import type { CommonNodeType } from '../types' -export enum NoteTheme { - blue = 'blue', - cyan = 'cyan', - green = 'green', - yellow = 'yellow', - pink = 'pink', - violet = 'violet', -} +export const NoteTheme = { + blue: 'blue', + cyan: 'cyan', + green: 'green', + yellow: 'yellow', + pink: 'pink', + violet: 'violet', +} as const +export type NoteTheme = typeof NoteTheme[keyof typeof NoteTheme] export type NoteNodeType = CommonNodeType & { text: string diff --git a/web/app/components/workflow/panel/chat-variable-panel/components/__tests__/variable-modal.spec.tsx b/web/app/components/workflow/panel/chat-variable-panel/components/__tests__/variable-modal.spec.tsx index 319e3803f4..297b534a6a 100644 --- a/web/app/components/workflow/panel/chat-variable-panel/components/__tests__/variable-modal.spec.tsx +++ b/web/app/components/workflow/panel/chat-variable-panel/components/__tests__/variable-modal.spec.tsx @@ -150,7 +150,7 @@ describe('variable-modal', () => { await user.click(screen.getByText('workflow.chatVariable.modal.editInJSON')) await waitFor(() => { - expect(screen.getByText('Loading...')).toBeInTheDocument() + expect(screen.getByTestId('monaco-editor')).toBeInTheDocument() }) await user.click(screen.getByText('workflow.chatVariable.modal.editInForm')) 
expect(screen.getByDisplayValue('enabled')).toBeInTheDocument() diff --git a/web/app/components/workflow/selection-contextmenu.tsx b/web/app/components/workflow/selection-contextmenu.tsx index 54e6ea2045..2b22df5012 100644 --- a/web/app/components/workflow/selection-contextmenu.tsx +++ b/web/app/components/workflow/selection-contextmenu.tsx @@ -1,13 +1,4 @@ -import type { ComponentType } from 'react' import type { Node } from './types' -import { - RiAlignBottom, - RiAlignCenter, - RiAlignJustify, - RiAlignLeft, - RiAlignRight, - RiAlignTop, -} from '@remixicon/react' import { produce } from 'immer' import { memo, @@ -24,11 +15,11 @@ import { ContextMenuGroupLabel, ContextMenuItem, ContextMenuSeparator, - ContextMenuTrigger, } from '@/app/components/base/ui/context-menu' -import { useNodesReadOnly, useNodesSyncDraft } from './hooks' +import { useNodesInteractions, useNodesReadOnly, useNodesSyncDraft } from './hooks' import { useSelectionInteractions } from './hooks/use-selection-interactions' import { useWorkflowHistory, WorkflowHistoryEvent } from './hooks/use-workflow-history' +import ShortcutsName from './shortcuts-name' import { useStore, useWorkflowStore } from './store' const AlignType = { @@ -44,13 +35,6 @@ const AlignType = { type AlignTypeValue = (typeof AlignType)[keyof typeof AlignType] -type SelectionMenuPosition = { - left: number - top: number -} - -type ContainerRect = Pick - type AlignBounds = { minX: number maxX: number @@ -60,7 +44,7 @@ type AlignBounds = { type MenuItem = { alignType: AlignTypeValue - icon: ComponentType<{ className?: string }> + icon: string iconClassName?: string translationKey: string } @@ -70,53 +54,27 @@ type MenuSection = { items: MenuItem[] } -const MENU_WIDTH = 240 -const MENU_HEIGHT = 380 - const menuSections: MenuSection[] = [ { titleKey: 'operator.vertical', items: [ - { alignType: AlignType.Top, icon: RiAlignTop, translationKey: 'operator.alignTop' }, - { alignType: AlignType.Middle, icon: RiAlignCenter, iconClassName: 
'rotate-90', translationKey: 'operator.alignMiddle' }, - { alignType: AlignType.Bottom, icon: RiAlignBottom, translationKey: 'operator.alignBottom' }, - { alignType: AlignType.DistributeVertical, icon: RiAlignJustify, iconClassName: 'rotate-90', translationKey: 'operator.distributeVertical' }, + { alignType: AlignType.Top, icon: 'i-ri-align-top', translationKey: 'operator.alignTop' }, + { alignType: AlignType.Middle, icon: 'i-ri-align-center', iconClassName: 'rotate-90', translationKey: 'operator.alignMiddle' }, + { alignType: AlignType.Bottom, icon: 'i-ri-align-bottom', translationKey: 'operator.alignBottom' }, + { alignType: AlignType.DistributeVertical, icon: 'i-ri-align-justify', iconClassName: 'rotate-90', translationKey: 'operator.distributeVertical' }, ], }, { titleKey: 'operator.horizontal', items: [ - { alignType: AlignType.Left, icon: RiAlignLeft, translationKey: 'operator.alignLeft' }, - { alignType: AlignType.Center, icon: RiAlignCenter, translationKey: 'operator.alignCenter' }, - { alignType: AlignType.Right, icon: RiAlignRight, translationKey: 'operator.alignRight' }, - { alignType: AlignType.DistributeHorizontal, icon: RiAlignJustify, translationKey: 'operator.distributeHorizontal' }, + { alignType: AlignType.Left, icon: 'i-ri-align-left', translationKey: 'operator.alignLeft' }, + { alignType: AlignType.Center, icon: 'i-ri-align-center', translationKey: 'operator.alignCenter' }, + { alignType: AlignType.Right, icon: 'i-ri-align-right', translationKey: 'operator.alignRight' }, + { alignType: AlignType.DistributeHorizontal, icon: 'i-ri-align-justify', translationKey: 'operator.distributeHorizontal' }, ], }, ] -const getMenuPosition = ( - selectionMenu: SelectionMenuPosition | undefined, - containerRect?: ContainerRect | null, -) => { - if (!selectionMenu) - return { left: 0, top: 0 } - - let { left, top } = selectionMenu - - if (containerRect) { - if (left + MENU_WIDTH > containerRect.width) - left = left - MENU_WIDTH - - if (top + MENU_HEIGHT > 
containerRect.height) - top = top - MENU_HEIGHT - - left = Math.max(0, left) - top = Math.max(0, top) - } - - return { left, top } -} - const getAlignableNodes = (nodes: Node[], selectedNodes: Node[]) => { const selectedNodeIds = new Set(selectedNodes.map(node => node.id)) const childNodeIds = new Set() @@ -266,6 +224,7 @@ const SelectionContextmenu = () => { const { t } = useTranslation() const { getNodesReadOnly } = useNodesReadOnly() const { handleSelectionContextmenuCancel } = useSelectionInteractions() + const { handleNodesCopy, handleNodesDelete, handleNodesDuplicate } = useNodesInteractions() const selectionMenu = useStore(s => s.selectionMenu) const store = useStoreApi() const workflowStore = useWorkflowStore() @@ -275,9 +234,18 @@ const SelectionContextmenu = () => { const { handleSyncWorkflowDraft } = useNodesSyncDraft() const { saveStateToHistory } = useWorkflowHistory() - const menuPosition = useMemo(() => { - const container = document.querySelector('#workflow-container') - return getMenuPosition(selectionMenu, container?.getBoundingClientRect()) + const anchor = useMemo(() => { + if (!selectionMenu) + return undefined + + return { + getBoundingClientRect: () => DOMRect.fromRect({ + width: 0, + height: 0, + x: selectionMenu.clientX, + y: selectionMenu.clientY, + }), + } }, [selectionMenu]) useEffect(() => { @@ -285,6 +253,21 @@ const SelectionContextmenu = () => { handleSelectionContextmenuCancel() }, [selectionMenu, selectedNodes.length, handleSelectionContextmenuCancel]) + const handleCopyNodes = useCallback(() => { + handleNodesCopy() + handleSelectionContextmenuCancel() + }, [handleNodesCopy, handleSelectionContextmenuCancel]) + + const handleDuplicateNodes = useCallback(() => { + handleNodesDuplicate() + handleSelectionContextmenuCancel() + }, [handleNodesDuplicate, handleSelectionContextmenuCancel]) + + const handleDeleteNodes = useCallback(() => { + handleNodesDelete() + handleSelectionContextmenuCancel() + }, [handleNodesDelete, 
handleSelectionContextmenuCancel]) + const handleAlignNodes = useCallback((alignType: AlignTypeValue) => { if (getNodesReadOnly() || selectedNodes.length <= 1) { handleSelectionContextmenuCancel() @@ -352,49 +335,69 @@ const SelectionContextmenu = () => { return null return ( -
{ + if (!open) + handleSelectionContextmenuCancel() }} > - { - if (!open) - handleSelectionContextmenuCancel() - }} + - - - - - {menuSections.map((section, sectionIndex) => ( - - {sectionIndex > 0 && } - - {t(section.titleKey, { defaultValue: section.titleKey, ns: 'workflow' })} - - {section.items.map((item) => { - const Icon = item.icon - return ( - handleAlignNodes(item.alignType)} - > - - {t(item.translationKey, { defaultValue: item.translationKey, ns: 'workflow' })} - - ) - })} - - ))} - - -
+ + + {t('common.copy', { defaultValue: 'common.copy', ns: 'workflow' })} + + + + {t('common.duplicate', { defaultValue: 'common.duplicate', ns: 'workflow' })} + + + + + + + {t('operation.delete', { defaultValue: 'operation.delete', ns: 'common' })} + + + + + {menuSections.map((section, sectionIndex) => ( + + {sectionIndex > 0 && } + + {t(section.titleKey, { defaultValue: section.titleKey, ns: 'workflow' })} + + {section.items.map((item) => { + return ( + handleAlignNodes(item.alignType)} + > + + {t(item.translationKey, { defaultValue: item.translationKey, ns: 'workflow' })} + + ) + })} + + ))} + + ) } diff --git a/web/app/components/workflow/store/__tests__/workflow-store.spec.ts b/web/app/components/workflow/store/__tests__/workflow-store.spec.ts index df0288ac09..ee820b22bf 100644 --- a/web/app/components/workflow/store/__tests__/workflow-store.spec.ts +++ b/web/app/components/workflow/store/__tests__/workflow-store.spec.ts @@ -96,7 +96,7 @@ describe('createWorkflowStore', () => { ['showInputsPanel', 'setShowInputsPanel', true], ['showDebugAndPreviewPanel', 'setShowDebugAndPreviewPanel', true], ['panelMenu', 'setPanelMenu', { top: 10, left: 20 }], - ['selectionMenu', 'setSelectionMenu', { top: 50, left: 60 }], + ['selectionMenu', 'setSelectionMenu', { clientX: 50, clientY: 60 }], ['edgeMenu', 'setEdgeMenu', { clientX: 320, clientY: 180, edgeId: 'e1' }], ['showVariableInspectPanel', 'setShowVariableInspectPanel', true], ['initShowLastRunTab', 'setInitShowLastRunTab', true], diff --git a/web/app/components/workflow/store/workflow/panel-slice.ts b/web/app/components/workflow/store/workflow/panel-slice.ts index bf8b248c3a..83292ff77e 100644 --- a/web/app/components/workflow/store/workflow/panel-slice.ts +++ b/web/app/components/workflow/store/workflow/panel-slice.ts @@ -16,8 +16,8 @@ export type PanelSliceShape = { } setPanelMenu: (panelMenu: PanelSliceShape['panelMenu']) => void selectionMenu?: { - top: number - left: number + clientX: number + clientY: number } 
setSelectionMenu: (selectionMenu: PanelSliceShape['selectionMenu']) => void edgeMenu?: { diff --git a/web/app/components/workflow/variable-inspect/right.tsx b/web/app/components/workflow/variable-inspect/right.tsx index af939d41ea..893c477d04 100644 --- a/web/app/components/workflow/variable-inspect/right.tsx +++ b/web/app/components/workflow/variable-inspect/right.tsx @@ -174,7 +174,7 @@ const Right = ({ {currentNodeVar?.var && ( <> { - [VarInInspectType.environment, VarInInspectType.conversation, VarInInspectType.system].includes(currentNodeVar.nodeType as VarInInspectType) && ( + ([VarInInspectType.environment, VarInInspectType.conversation, VarInInspectType.system] as VarInInspectType[]).includes(currentNodeVar.nodeType as VarInInspectType) && ( }, -) { - const { owner, repo } = (await params) - try { - const releasesRes = await octokit.request('GET /repos/{owner}/{repo}/releases', { - owner, - repo, - headers: { - 'X-GitHub-Api-Version': '2022-11-28', - }, - }) - return NextResponse.json(releasesRes) - } - catch (error) { - if (error instanceof RequestError) - return NextResponse.json(error.response) - else - throw error - } -} diff --git a/web/config/index.ts b/web/config/index.ts index eed914726c..999aa754e6 100644 --- a/web/config/index.ts +++ b/web/config/index.ts @@ -292,9 +292,6 @@ export const resetHITLInputReg = () => HITL_INPUT_REG.lastIndex = 0 export const DISABLE_UPLOAD_IMAGE_AS_ICON = env.NEXT_PUBLIC_DISABLE_UPLOAD_IMAGE_AS_ICON -export const GITHUB_ACCESS_TOKEN - = env.NEXT_PUBLIC_GITHUB_ACCESS_TOKEN - export const SUPPORT_INSTALL_LOCAL_FILE_EXTENSIONS = '.difypkg,.difybndl' export const FULL_DOC_PREVIEW_LENGTH = 50 diff --git a/web/docs/test.md b/web/docs/test.md index 2facdbb29f..bc1546a991 100644 --- a/web/docs/test.md +++ b/web/docs/test.md @@ -8,7 +8,7 @@ When I ask you to write/refactor/fix tests, follow these rules by default. 
- **Framework**: Next.js 15 + React 19 + TypeScript - **Testing Tools**: Vitest 4.0.16 + React Testing Library 16.0 -- **Test Environment**: jsdom +- **Test Environment**: happy-dom - **File Naming**: `ComponentName.spec.tsx` inside a same-level `__tests__/` directory - **Placement Rule**: Component, hook, and utility tests must live in a sibling `__tests__/` folder at the same level as the source under test. For example, `foo/index.tsx` maps to `foo/__tests__/index.spec.tsx`, and `foo/bar.ts` maps to `foo/__tests__/bar.spec.ts`. @@ -30,7 +30,7 @@ pnpm test path/to/file.spec.tsx ## Project Test Setup -- **Configuration**: `vitest.config.ts` sets the `jsdom` environment, loads the Testing Library presets, and respects our path aliases (`@/...`). Check this file before adding new transformers or module name mappers. +- **Configuration**: `vite.config.ts` sets the `happy-dom` environment, loads the Testing Library presets, and respects our path aliases (`@/...`). Check this file before adding new transformers or module name mappers. - **Global setup**: `vitest.setup.ts` already imports `@testing-library/jest-dom`, runs `cleanup()` after every test, and defines shared mocks (for example `react-i18next`). Add any environment-level mocks (for example `ResizeObserver`, `matchMedia`, `IntersectionObserver`, `TextEncoder`, `crypto`) here so they are shared consistently. - **Reusable mocks**: Place shared mock factories inside `web/__mocks__/` and use `vi.mock('module-name')` to point to them rather than redefining mocks in every spec. - **Mocking behavior**: Modules are not mocked automatically. Use `vi.mock(...)` in tests, or place global mocks in `vitest.setup.ts`. @@ -283,16 +283,6 @@ Reserve snapshots for static, deterministic fragments (icons, badges, layout chr **Note**: Dify is a desktop application. **No need for** responsive/mobile testing. -### 12. Mock API - -Use Nock to mock API calls. 
Example: - -```ts -const mockGithubStar = (status: number, body: Record, delayMs = 0) => { - return nock(GITHUB_HOST).get(GITHUB_PATH).delay(delayMs).reply(status, body) -} -``` - ## Code Style ### Example Structure diff --git a/web/env.ts b/web/env.ts index 8ecde76143..55709918a8 100644 --- a/web/env.ts +++ b/web/env.ts @@ -66,10 +66,6 @@ const clientSchema = { NEXT_PUBLIC_ENABLE_WEBSITE_FIRECRAWL: coercedBoolean.default(true), NEXT_PUBLIC_ENABLE_WEBSITE_JINAREADER: coercedBoolean.default(true), NEXT_PUBLIC_ENABLE_WEBSITE_WATERCRAWL: coercedBoolean.default(false), - /** - * Github Access Token, used for invoking Github API - */ - NEXT_PUBLIC_GITHUB_ACCESS_TOKEN: z.string().optional(), /** * The maximum number of tokens for segmentation */ @@ -171,7 +167,6 @@ export const env = createEnv({ NEXT_PUBLIC_ENABLE_WEBSITE_FIRECRAWL: isServer ? process.env.NEXT_PUBLIC_ENABLE_WEBSITE_FIRECRAWL : getRuntimeEnvFromBody('enableWebsiteFirecrawl'), NEXT_PUBLIC_ENABLE_WEBSITE_JINAREADER: isServer ? process.env.NEXT_PUBLIC_ENABLE_WEBSITE_JINAREADER : getRuntimeEnvFromBody('enableWebsiteJinareader'), NEXT_PUBLIC_ENABLE_WEBSITE_WATERCRAWL: isServer ? process.env.NEXT_PUBLIC_ENABLE_WEBSITE_WATERCRAWL : getRuntimeEnvFromBody('enableWebsiteWatercrawl'), - NEXT_PUBLIC_GITHUB_ACCESS_TOKEN: isServer ? process.env.NEXT_PUBLIC_GITHUB_ACCESS_TOKEN : getRuntimeEnvFromBody('githubAccessToken'), NEXT_PUBLIC_INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH: isServer ? process.env.NEXT_PUBLIC_INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH : getRuntimeEnvFromBody('indexingMaxSegmentationTokensLength'), NEXT_PUBLIC_IS_MARKETPLACE: isServer ? process.env.NEXT_PUBLIC_IS_MARKETPLACE : getRuntimeEnvFromBody('isMarketplace'), NEXT_PUBLIC_LOOP_NODE_MAX_COUNT: isServer ? 
process.env.NEXT_PUBLIC_LOOP_NODE_MAX_COUNT : getRuntimeEnvFromBody('loopNodeMaxCount'), diff --git a/web/eslint-suppressions.json b/web/eslint-suppressions.json index e28d915e66..e34300993a 100644 --- a/web/eslint-suppressions.json +++ b/web/eslint-suppressions.json @@ -4,12 +4,9 @@ "count": 1 } }, - "__tests__/check-i18n.test.ts": { - "regexp/no-unused-capturing-group": { + "__mocks__/zustand.ts": { + "no-barrel-files/no-barrel-files": { "count": 1 - }, - "ts/no-explicit-any": { - "count": 2 } }, "__tests__/document-detail-navigation-fix.test.tsx": { @@ -1496,6 +1493,11 @@ "count": 2 } }, + "app/components/base/amplitude/index.ts": { + "no-barrel-files/no-barrel-files": { + "count": 2 + } + }, "app/components/base/amplitude/utils.ts": { "ts/no-explicit-any": { "count": 2 @@ -1855,6 +1857,11 @@ "count": 3 } }, + "app/components/base/chat/types.ts": { + "no-barrel-files/no-barrel-files": { + "count": 3 + } + }, "app/components/base/chat/utils.ts": { "ts/no-explicit-any": { "count": 10 @@ -1937,6 +1944,11 @@ "count": 4 } }, + "app/components/base/date-and-time-picker/hooks.ts": { + "react/no-unnecessary-use-prefix": { + "count": 2 + } + }, "app/components/base/date-and-time-picker/time-picker/header.tsx": { "tailwindcss/enforce-consistent-class-order": { "count": 1 @@ -1957,6 +1969,11 @@ "count": 2 } }, + "app/components/base/date-and-time-picker/utils/dayjs.ts": { + "no-barrel-files/no-barrel-files": { + "count": 1 + } + }, "app/components/base/date-and-time-picker/year-and-month-picker/header.tsx": { "tailwindcss/enforce-consistent-class-order": { "count": 1 @@ -2018,6 +2035,11 @@ "count": 1 } }, + "app/components/base/features/index.tsx": { + "no-barrel-files/no-barrel-files": { + "count": 1 + } + }, "app/components/base/features/new-feature-panel/annotation-reply/annotation-ctrl-button.tsx": { "no-restricted-imports": { "count": 2 @@ -2184,10 +2206,23 @@ "no-restricted-imports": { "count": 1 }, + "react/no-unnecessary-use-prefix": { + "count": 1 + }, 
"ts/no-explicit-any": { "count": 3 } }, + "app/components/base/file-uploader/index.ts": { + "no-barrel-files/no-barrel-files": { + "count": 7 + } + }, + "app/components/base/file-uploader/pdf-highlighter-adapter.tsx": { + "no-barrel-files/no-barrel-files": { + "count": 2 + } + }, "app/components/base/file-uploader/pdf-preview.tsx": { "no-restricted-imports": { "count": 1 @@ -2221,6 +2256,11 @@ "count": 6 } }, + "app/components/base/form/components/base/index.tsx": { + "no-barrel-files/no-barrel-files": { + "count": 2 + } + }, "app/components/base/form/components/field/checkbox.tsx": { "tailwindcss/enforce-consistent-class-order": { "count": 1 @@ -2315,6 +2355,11 @@ "count": 2 } }, + "app/components/base/form/hooks/index.ts": { + "no-barrel-files/no-barrel-files": { + "count": 3 + } + }, "app/components/base/form/hooks/use-check-validated.ts": { "no-restricted-imports": { "count": 1 @@ -2349,6 +2394,271 @@ "count": 4 } }, + "app/components/base/icons/src/image/llm/index.ts": { + "no-barrel-files/no-barrel-files": { + "count": 9 + } + }, + "app/components/base/icons/src/public/avatar/index.ts": { + "no-barrel-files/no-barrel-files": { + "count": 2 + } + }, + "app/components/base/icons/src/public/billing/index.ts": { + "no-barrel-files/no-barrel-files": { + "count": 12 + } + }, + "app/components/base/icons/src/public/common/index.ts": { + "no-barrel-files/no-barrel-files": { + "count": 16 + } + }, + "app/components/base/icons/src/public/education/index.ts": { + "no-barrel-files/no-barrel-files": { + "count": 1 + } + }, + "app/components/base/icons/src/public/files/index.ts": { + "no-barrel-files/no-barrel-files": { + "count": 11 + } + }, + "app/components/base/icons/src/public/knowledge/dataset-card/index.ts": { + "no-barrel-files/no-barrel-files": { + "count": 5 + } + }, + "app/components/base/icons/src/public/knowledge/index.ts": { + "no-barrel-files/no-barrel-files": { + "count": 8 + } + }, + "app/components/base/icons/src/public/knowledge/online-drive/index.ts": { 
+ "no-barrel-files/no-barrel-files": { + "count": 3 + } + }, + "app/components/base/icons/src/public/llm/index.ts": { + "no-barrel-files/no-barrel-files": { + "count": 50 + } + }, + "app/components/base/icons/src/public/model/index.ts": { + "no-barrel-files/no-barrel-files": { + "count": 1 + } + }, + "app/components/base/icons/src/public/other/index.ts": { + "no-barrel-files/no-barrel-files": { + "count": 6 + } + }, + "app/components/base/icons/src/public/plugins/index.ts": { + "no-barrel-files/no-barrel-files": { + "count": 7 + } + }, + "app/components/base/icons/src/public/thought/index.ts": { + "no-barrel-files/no-barrel-files": { + "count": 5 + } + }, + "app/components/base/icons/src/public/tracing/index.ts": { + "no-barrel-files/no-barrel-files": { + "count": 21 + } + }, + "app/components/base/icons/src/vender/features/index.ts": { + "no-barrel-files/no-barrel-files": { + "count": 10 + } + }, + "app/components/base/icons/src/vender/knowledge/index.ts": { + "no-barrel-files/no-barrel-files": { + "count": 16 + } + }, + "app/components/base/icons/src/vender/line/alertsAndFeedback/index.ts": { + "no-barrel-files/no-barrel-files": { + "count": 4 + } + }, + "app/components/base/icons/src/vender/line/arrows/index.ts": { + "no-barrel-files/no-barrel-files": { + "count": 9 + } + }, + "app/components/base/icons/src/vender/line/communication/index.ts": { + "no-barrel-files/no-barrel-files": { + "count": 6 + } + }, + "app/components/base/icons/src/vender/line/development/index.ts": { + "no-barrel-files/no-barrel-files": { + "count": 14 + } + }, + "app/components/base/icons/src/vender/line/editor/index.ts": { + "no-barrel-files/no-barrel-files": { + "count": 8 + } + }, + "app/components/base/icons/src/vender/line/education/index.ts": { + "no-barrel-files/no-barrel-files": { + "count": 1 + } + }, + "app/components/base/icons/src/vender/line/files/index.ts": { + "no-barrel-files/no-barrel-files": { + "count": 11 + } + }, + 
"app/components/base/icons/src/vender/line/financeAndECommerce/index.ts": { + "no-barrel-files/no-barrel-files": { + "count": 7 + } + }, + "app/components/base/icons/src/vender/line/general/index.ts": { + "no-barrel-files/no-barrel-files": { + "count": 30 + } + }, + "app/components/base/icons/src/vender/line/images/index.ts": { + "no-barrel-files/no-barrel-files": { + "count": 1 + } + }, + "app/components/base/icons/src/vender/line/layout/index.ts": { + "no-barrel-files/no-barrel-files": { + "count": 4 + } + }, + "app/components/base/icons/src/vender/line/mediaAndDevices/index.ts": { + "no-barrel-files/no-barrel-files": { + "count": 6 + } + }, + "app/components/base/icons/src/vender/line/others/index.ts": { + "no-barrel-files/no-barrel-files": { + "count": 10 + } + }, + "app/components/base/icons/src/vender/line/shapes/index.ts": { + "no-barrel-files/no-barrel-files": { + "count": 1 + } + }, + "app/components/base/icons/src/vender/line/time/index.ts": { + "no-barrel-files/no-barrel-files": { + "count": 4 + } + }, + "app/components/base/icons/src/vender/line/users/index.ts": { + "no-barrel-files/no-barrel-files": { + "count": 2 + } + }, + "app/components/base/icons/src/vender/line/weather/index.ts": { + "no-barrel-files/no-barrel-files": { + "count": 1 + } + }, + "app/components/base/icons/src/vender/other/index.ts": { + "no-barrel-files/no-barrel-files": { + "count": 9 + } + }, + "app/components/base/icons/src/vender/pipeline/index.ts": { + "no-barrel-files/no-barrel-files": { + "count": 3 + } + }, + "app/components/base/icons/src/vender/plugin/index.ts": { + "no-barrel-files/no-barrel-files": { + "count": 3 + } + }, + "app/components/base/icons/src/vender/solid/FinanceAndECommerce/index.ts": { + "no-barrel-files/no-barrel-files": { + "count": 2 + } + }, + "app/components/base/icons/src/vender/solid/alertsAndFeedback/index.ts": { + "no-barrel-files/no-barrel-files": { + "count": 1 + } + }, + "app/components/base/icons/src/vender/solid/arrows/index.ts": { + 
"no-barrel-files/no-barrel-files": { + "count": 5 + } + }, + "app/components/base/icons/src/vender/solid/communication/index.ts": { + "no-barrel-files/no-barrel-files": { + "count": 12 + } + }, + "app/components/base/icons/src/vender/solid/development/index.ts": { + "no-barrel-files/no-barrel-files": { + "count": 13 + } + }, + "app/components/base/icons/src/vender/solid/editor/index.ts": { + "no-barrel-files/no-barrel-files": { + "count": 5 + } + }, + "app/components/base/icons/src/vender/solid/education/index.ts": { + "no-barrel-files/no-barrel-files": { + "count": 4 + } + }, + "app/components/base/icons/src/vender/solid/files/index.ts": { + "no-barrel-files/no-barrel-files": { + "count": 4 + } + }, + "app/components/base/icons/src/vender/solid/general/index.ts": { + "no-barrel-files/no-barrel-files": { + "count": 18 + } + }, + "app/components/base/icons/src/vender/solid/layout/index.ts": { + "no-barrel-files/no-barrel-files": { + "count": 1 + } + }, + "app/components/base/icons/src/vender/solid/mediaAndDevices/index.ts": { + "no-barrel-files/no-barrel-files": { + "count": 12 + } + }, + "app/components/base/icons/src/vender/solid/security/index.ts": { + "no-barrel-files/no-barrel-files": { + "count": 1 + } + }, + "app/components/base/icons/src/vender/solid/shapes/index.ts": { + "no-barrel-files/no-barrel-files": { + "count": 3 + } + }, + "app/components/base/icons/src/vender/solid/users/index.ts": { + "no-barrel-files/no-barrel-files": { + "count": 4 + } + }, + "app/components/base/icons/src/vender/system/index.ts": { + "no-barrel-files/no-barrel-files": { + "count": 1 + } + }, + "app/components/base/icons/src/vender/workflow/index.ts": { + "no-barrel-files/no-barrel-files": { + "count": 31 + } + }, "app/components/base/icons/utils.ts": { "ts/no-explicit-any": { "count": 3 @@ -2459,6 +2769,11 @@ "count": 3 } }, + "app/components/base/markdown-blocks/index.ts": { + "no-barrel-files/no-barrel-files": { + "count": 11 + } + }, 
"app/components/base/markdown-blocks/link.tsx": { "ts/no-explicit-any": { "count": 1 @@ -2586,6 +2901,11 @@ "count": 2 } }, + "app/components/base/notion-page-selector/index.tsx": { + "no-barrel-files/no-barrel-files": { + "count": 1 + } + }, "app/components/base/notion-page-selector/page-selector/index.tsx": { "react/set-state-in-effect": { "count": 1 @@ -2651,6 +2971,9 @@ } }, "app/components/base/prompt-editor/plugins/context-block/index.tsx": { + "no-barrel-files/no-barrel-files": { + "count": 3 + }, "react-refresh/only-export-components": { "count": 2 } @@ -2661,6 +2984,9 @@ } }, "app/components/base/prompt-editor/plugins/current-block/index.tsx": { + "no-barrel-files/no-barrel-files": { + "count": 3 + }, "react-refresh/only-export-components": { "count": 2 } @@ -2676,6 +3002,9 @@ } }, "app/components/base/prompt-editor/plugins/error-message-block/index.tsx": { + "no-barrel-files/no-barrel-files": { + "count": 3 + }, "react-refresh/only-export-components": { "count": 2 } @@ -2686,6 +3015,9 @@ } }, "app/components/base/prompt-editor/plugins/history-block/index.tsx": { + "no-barrel-files/no-barrel-files": { + "count": 3 + }, "react-refresh/only-export-components": { "count": 2 } @@ -2701,6 +3033,9 @@ } }, "app/components/base/prompt-editor/plugins/hitl-input-block/index.tsx": { + "no-barrel-files/no-barrel-files": { + "count": 3 + }, "react-refresh/only-export-components": { "count": 3 } @@ -2726,11 +3061,17 @@ } }, "app/components/base/prompt-editor/plugins/last-run-block/index.tsx": { + "no-barrel-files/no-barrel-files": { + "count": 3 + }, "react-refresh/only-export-components": { "count": 2 } }, "app/components/base/prompt-editor/plugins/query-block/index.tsx": { + "no-barrel-files/no-barrel-files": { + "count": 3 + }, "react-refresh/only-export-components": { "count": 2 } @@ -2741,6 +3082,9 @@ } }, "app/components/base/prompt-editor/plugins/request-url-block/index.tsx": { + "no-barrel-files/no-barrel-files": { + "count": 3 + }, 
"react-refresh/only-export-components": { "count": 2 } @@ -2767,8 +3111,11 @@ } }, "app/components/base/prompt-editor/plugins/workflow-variable-block/index.tsx": { + "no-barrel-files/no-barrel-files": { + "count": 3 + }, "react-refresh/only-export-components": { - "count": 4 + "count": 3 } }, "app/components/base/prompt-editor/plugins/workflow-variable-block/node.tsx": { @@ -2950,6 +3297,11 @@ "count": 1 } }, + "app/components/base/text-generation/types.ts": { + "no-barrel-files/no-barrel-files": { + "count": 3 + } + }, "app/components/base/textarea/index.stories.tsx": { "no-console": { "count": 1 @@ -2963,6 +3315,11 @@ "count": 1 } }, + "app/components/base/toast/index.tsx": { + "no-barrel-files/no-barrel-files": { + "count": 2 + } + }, "app/components/base/video-gallery/VideoPlayer.tsx": { "react/set-state-in-effect": { "count": 1 @@ -3025,6 +3382,11 @@ "count": 1 } }, + "app/components/billing/plan/assets/index.tsx": { + "no-barrel-files/no-barrel-files": { + "count": 4 + } + }, "app/components/billing/plan/index.tsx": { "tailwindcss/enforce-consistent-class-order": { "count": 2 @@ -3033,6 +3395,11 @@ "count": 2 } }, + "app/components/billing/pricing/assets/index.tsx": { + "no-barrel-files/no-barrel-files": { + "count": 12 + } + }, "app/components/billing/pricing/plan-switcher/plan-range-switcher.tsx": { "erasable-syntax-only/enums": { "count": 1 @@ -3211,6 +3578,9 @@ } }, "app/components/datasets/create-from-pipeline/create-options/create-from-dsl-modal/index.tsx": { + "no-barrel-files/no-barrel-files": { + "count": 1 + }, "no-restricted-imports": { "count": 1 }, @@ -3332,6 +3702,16 @@ "count": 1 } }, + "app/components/datasets/create/step-one/components/index.ts": { + "no-barrel-files/no-barrel-files": { + "count": 3 + } + }, + "app/components/datasets/create/step-one/hooks/index.ts": { + "no-barrel-files/no-barrel-files": { + "count": 2 + } + }, "app/components/datasets/create/step-one/index.tsx": { "tailwindcss/enforce-consistent-class-order": { "count": 1 
@@ -3355,6 +3735,11 @@ "count": 1 } }, + "app/components/datasets/create/step-two/components/index.ts": { + "no-barrel-files/no-barrel-files": { + "count": 5 + } + }, "app/components/datasets/create/step-two/components/indexing-mode-section.tsx": { "no-restricted-imports": { "count": 2 @@ -3365,6 +3750,11 @@ "count": 1 } }, + "app/components/datasets/create/step-two/hooks/index.ts": { + "no-barrel-files/no-barrel-files": { + "count": 10 + } + }, "app/components/datasets/create/step-two/hooks/use-document-creation.ts": { "no-restricted-imports": { "count": 1 @@ -3379,6 +3769,9 @@ } }, "app/components/datasets/create/step-two/index.tsx": { + "no-barrel-files/no-barrel-files": { + "count": 1 + }, "no-restricted-imports": { "count": 1 }, @@ -3537,6 +3930,21 @@ "count": 1 } }, + "app/components/datasets/documents/components/document-list/components/index.ts": { + "no-barrel-files/no-barrel-files": { + "count": 4 + } + }, + "app/components/datasets/documents/components/document-list/hooks/index.ts": { + "no-barrel-files/no-barrel-files": { + "count": 4 + } + }, + "app/components/datasets/documents/components/document-list/index.tsx": { + "no-barrel-files/no-barrel-files": { + "count": 2 + } + }, "app/components/datasets/documents/components/documents-header.tsx": { "no-restricted-imports": { "count": 1 @@ -3727,6 +4135,11 @@ "count": 2 } }, + "app/components/datasets/documents/create-from-pipeline/hooks/index.ts": { + "no-barrel-files/no-barrel-files": { + "count": 5 + } + }, "app/components/datasets/documents/create-from-pipeline/left-header.tsx": { "tailwindcss/enforce-consistent-class-order": { "count": 3 @@ -3781,6 +4194,11 @@ "count": 3 } }, + "app/components/datasets/documents/create-from-pipeline/steps/index.ts": { + "no-barrel-files/no-barrel-files": { + "count": 3 + } + }, "app/components/datasets/documents/create-from-pipeline/types.tsx": { "erasable-syntax-only/enums": { "count": 1 @@ -3878,6 +4296,11 @@ "count": 2 } }, + 
"app/components/datasets/documents/detail/completed/components/index.ts": { + "no-barrel-files/no-barrel-files": { + "count": 3 + } + }, "app/components/datasets/documents/detail/completed/components/menu-bar.tsx": { "no-restricted-imports": { "count": 2 @@ -3891,6 +4314,11 @@ "count": 1 } }, + "app/components/datasets/documents/detail/completed/hooks/index.ts": { + "no-barrel-files/no-barrel-files": { + "count": 10 + } + }, "app/components/datasets/documents/detail/completed/hooks/use-child-segment-data.ts": { "no-restricted-imports": { "count": 1 @@ -3907,6 +4335,9 @@ } }, "app/components/datasets/documents/detail/completed/index.tsx": { + "no-barrel-files/no-barrel-files": { + "count": 2 + }, "react-refresh/only-export-components": { "count": 1 }, @@ -3947,6 +4378,11 @@ "count": 1 } }, + "app/components/datasets/documents/detail/embedding/components/index.ts": { + "no-barrel-files/no-barrel-files": { + "count": 4 + } + }, "app/components/datasets/documents/detail/embedding/components/segment-progress.tsx": { "tailwindcss/enforce-consistent-class-order": { "count": 1 @@ -3957,6 +4393,11 @@ "count": 3 } }, + "app/components/datasets/documents/detail/embedding/hooks/index.ts": { + "no-barrel-files/no-barrel-files": { + "count": 2 + } + }, "app/components/datasets/documents/detail/embedding/index.tsx": { "no-restricted-imports": { "count": 1 @@ -3988,6 +4429,11 @@ "count": 4 } }, + "app/components/datasets/documents/detail/metadata/index.tsx": { + "no-barrel-files/no-barrel-files": { + "count": 1 + } + }, "app/components/datasets/documents/detail/segment-add/index.tsx": { "erasable-syntax-only/enums": { "count": 1 @@ -4564,6 +5010,11 @@ "count": 2 } }, + "app/components/goto-anything/actions/commands/index.ts": { + "no-barrel-files/no-barrel-files": { + "count": 5 + } + }, "app/components/goto-anything/actions/commands/registry.ts": { "ts/no-explicit-any": { "count": 3 @@ -4582,6 +5033,11 @@ "count": 1 } }, + "app/components/goto-anything/actions/index.ts": { + 
"no-barrel-files/no-barrel-files": { + "count": 6 + } + }, "app/components/goto-anything/actions/types.ts": { "ts/no-explicit-any": { "count": 2 @@ -4597,6 +5053,11 @@ "count": 1 } }, + "app/components/goto-anything/components/index.ts": { + "no-barrel-files/no-barrel-files": { + "count": 10 + } + }, "app/components/goto-anything/context.tsx": { "react-refresh/only-export-components": { "count": 1 @@ -4605,6 +5066,11 @@ "count": 4 } }, + "app/components/goto-anything/hooks/index.ts": { + "no-barrel-files/no-barrel-files": { + "count": 8 + } + }, "app/components/goto-anything/hooks/use-goto-anything-results.ts": { "@tanstack/query/exhaustive-deps": { "count": 1 @@ -4674,6 +5140,11 @@ "count": 1 } }, + "app/components/header/account-setting/data-source-page-new/hooks/index.ts": { + "no-barrel-files/no-barrel-files": { + "count": 2 + } + }, "app/components/header/account-setting/data-source-page-new/hooks/use-marketplace-all-plugins.ts": { "ts/no-explicit-any": { "count": 1 @@ -4766,6 +5237,9 @@ } }, "app/components/header/account-setting/model-provider-page/hooks.ts": { + "react/no-unnecessary-use-prefix": { + "count": 1 + }, "ts/no-explicit-any": { "count": 2 } @@ -4811,6 +5285,11 @@ "count": 1 } }, + "app/components/header/account-setting/model-provider-page/model-auth/hooks/index.ts": { + "no-barrel-files/no-barrel-files": { + "count": 6 + } + }, "app/components/header/account-setting/model-provider-page/model-auth/hooks/use-auth.ts": { "no-restricted-imports": { "count": 1 @@ -4819,11 +5298,21 @@ "count": 6 } }, + "app/components/header/account-setting/model-provider-page/model-auth/hooks/use-custom-models.ts": { + "react/no-unnecessary-use-prefix": { + "count": 2 + } + }, "app/components/header/account-setting/model-provider-page/model-auth/hooks/use-model-form-schemas.ts": { "ts/no-explicit-any": { "count": 2 } }, + "app/components/header/account-setting/model-provider-page/model-auth/index.tsx": { + "no-barrel-files/no-barrel-files": { + "count": 8 + } + }, 
"app/components/header/account-setting/model-provider-page/model-auth/switch-credential-in-load-balancing.tsx": { "no-restricted-imports": { "count": 1 @@ -4936,6 +5425,11 @@ "count": 1 } }, + "app/components/header/account-setting/model-provider-page/utils.ts": { + "no-barrel-files/no-barrel-files": { + "count": 2 + } + }, "app/components/header/account-setting/plugin-page/SerpapiPlugin.tsx": { "no-restricted-imports": { "count": 1 @@ -5038,6 +5532,11 @@ "count": 4 } }, + "app/components/plugins/install-plugin/hooks/use-fold-anim-into.ts": { + "react/no-unnecessary-use-prefix": { + "count": 1 + } + }, "app/components/plugins/install-plugin/install-bundle/index.tsx": { "erasable-syntax-only/enums": { "count": 1 @@ -5242,6 +5741,11 @@ "count": 1 } }, + "app/components/plugins/plugin-auth/hooks/use-get-api.ts": { + "react/no-unnecessary-use-prefix": { + "count": 1 + } + }, "app/components/plugins/plugin-auth/hooks/use-plugin-auth-action.ts": { "no-restricted-imports": { "count": 1 @@ -5251,6 +5755,9 @@ } }, "app/components/plugins/plugin-auth/index.tsx": { + "no-barrel-files/no-barrel-files": { + "count": 12 + }, "react-refresh/only-export-components": { "count": 3 } @@ -5264,6 +5771,9 @@ "erasable-syntax-only/enums": { "count": 2 }, + "no-barrel-files/no-barrel-files": { + "count": 2 + }, "ts/no-explicit-any": { "count": 1 } @@ -5325,6 +5835,16 @@ "count": 1 } }, + "app/components/plugins/plugin-detail-panel/detail-header.tsx": { + "no-barrel-files/no-barrel-files": { + "count": 1 + } + }, + "app/components/plugins/plugin-detail-panel/detail-header/components/index.ts": { + "no-barrel-files/no-barrel-files": { + "count": 2 + } + }, "app/components/plugins/plugin-detail-panel/detail-header/components/plugin-source-badge.tsx": { "no-restricted-imports": { "count": 1 @@ -5333,6 +5853,11 @@ "count": 1 } }, + "app/components/plugins/plugin-detail-panel/detail-header/hooks/index.ts": { + "no-barrel-files/no-barrel-files": { + "count": 3 + } + }, 
"app/components/plugins/plugin-detail-panel/endpoint-card.tsx": { "no-restricted-imports": { "count": 2 @@ -5425,6 +5950,9 @@ } }, "app/components/plugins/plugin-detail-panel/subscription-list/create/index.tsx": { + "no-barrel-files/no-barrel-files": { + "count": 3 + }, "no-restricted-imports": { "count": 3 } @@ -5461,6 +5989,9 @@ } }, "app/components/plugins/plugin-detail-panel/subscription-list/index.tsx": { + "no-barrel-files/no-barrel-files": { + "count": 3 + }, "react-refresh/only-export-components": { "count": 1 } @@ -5513,6 +6044,11 @@ "count": 1 } }, + "app/components/plugins/plugin-detail-panel/tool-selector/components/index.ts": { + "no-barrel-files/no-barrel-files": { + "count": 8 + } + }, "app/components/plugins/plugin-detail-panel/tool-selector/components/reasoning-config-form.tsx": { "no-restricted-imports": { "count": 2 @@ -5549,6 +6085,11 @@ "count": 2 } }, + "app/components/plugins/plugin-detail-panel/tool-selector/hooks/index.ts": { + "no-barrel-files/no-barrel-files": { + "count": 3 + } + }, "app/components/plugins/plugin-detail-panel/tool-selector/index.tsx": { "no-restricted-imports": { "count": 1 @@ -5769,11 +6310,6 @@ "count": 1 } }, - "app/components/rag-pipeline/components/chunk-card-list/types.ts": { - "erasable-syntax-only/enums": { - "count": 1 - } - }, "app/components/rag-pipeline/components/conversion.tsx": { "no-restricted-imports": { "count": 2 @@ -5931,11 +6467,6 @@ "count": 1 } }, - "app/components/rag-pipeline/components/panel/test-run/types.ts": { - "erasable-syntax-only/enums": { - "count": 1 - } - }, "app/components/rag-pipeline/components/publish-as-knowledge-pipeline-modal.tsx": { "no-restricted-imports": { "count": 1 @@ -5996,6 +6527,11 @@ "count": 2 } }, + "app/components/rag-pipeline/hooks/index.ts": { + "no-barrel-files/no-barrel-files": { + "count": 9 + } + }, "app/components/rag-pipeline/hooks/use-DSL.ts": { "no-restricted-imports": { "count": 1 @@ -6039,6 +6575,11 @@ "count": 2 } }, + 
"app/components/rag-pipeline/utils/index.ts": { + "no-barrel-files/no-barrel-files": { + "count": 1 + } + }, "app/components/rag-pipeline/utils/nodes.ts": { "ts/no-explicit-any": { "count": 1 @@ -6392,6 +6933,11 @@ "count": 2 } }, + "app/components/workflow-app/hooks/index.ts": { + "no-barrel-files/no-barrel-files": { + "count": 13 + } + }, "app/components/workflow-app/hooks/use-DSL.ts": { "no-restricted-imports": { "count": 1 @@ -6435,6 +6981,11 @@ "count": 2 } }, + "app/components/workflow/__tests__/fixtures.ts": { + "no-barrel-files/no-barrel-files": { + "count": 2 + } + }, "app/components/workflow/block-selector/all-start-blocks.tsx": { "react/set-state-in-effect": { "count": 1 @@ -6709,6 +7260,11 @@ "count": 1 } }, + "app/components/workflow/hooks-store/index.ts": { + "no-barrel-files/no-barrel-files": { + "count": 2 + } + }, "app/components/workflow/hooks-store/provider.tsx": { "react-refresh/only-export-components": { "count": 1 @@ -6724,6 +7280,11 @@ "count": 1 } }, + "app/components/workflow/hooks/index.ts": { + "no-barrel-files/no-barrel-files": { + "count": 26 + } + }, "app/components/workflow/hooks/use-checklist.ts": { "ts/no-empty-object-type": { "count": 1 @@ -6767,6 +7328,16 @@ "count": 1 } }, + "app/components/workflow/hooks/use-workflow-interactions.ts": { + "no-barrel-files/no-barrel-files": { + "count": 5 + } + }, + "app/components/workflow/hooks/use-workflow-run-event/index.ts": { + "no-barrel-files/no-barrel-files": { + "count": 19 + } + }, "app/components/workflow/hooks/use-workflow-run-event/use-workflow-agent-log.ts": { "ts/no-explicit-any": { "count": 1 @@ -6849,6 +7420,11 @@ "count": 1 } }, + "app/components/workflow/nodes/_base/components/collapse/index.tsx": { + "no-barrel-files/no-barrel-files": { + "count": 1 + } + }, "app/components/workflow/nodes/_base/components/config-vision.tsx": { "no-restricted-imports": { "count": 1 @@ -6992,6 +7568,9 @@ } }, "app/components/workflow/nodes/_base/components/layout/index.tsx": { + 
"no-barrel-files/no-barrel-files": { + "count": 7 + }, "react-refresh/only-export-components": { "count": 7 } @@ -7229,6 +7808,16 @@ "count": 2 } }, + "app/components/workflow/nodes/_base/components/variable/variable-label/hooks.ts": { + "react/no-unnecessary-use-prefix": { + "count": 2 + } + }, + "app/components/workflow/nodes/_base/components/variable/variable-label/index.tsx": { + "no-barrel-files/no-barrel-files": { + "count": 5 + } + }, "app/components/workflow/nodes/_base/components/workflow-panel/index.tsx": { "no-restricted-imports": { "count": 1 @@ -7260,6 +7849,9 @@ } }, "app/components/workflow/nodes/_base/components/workflow-panel/last-run/use-last-run.ts": { + "react/no-unnecessary-use-prefix": { + "count": 2 + }, "react/set-state-in-effect": { "count": 1 }, @@ -7487,6 +8079,9 @@ "erasable-syntax-only/enums": { "count": 1 }, + "no-barrel-files/no-barrel-files": { + "count": 1 + }, "ts/no-explicit-any": { "count": 1 } @@ -7894,6 +8489,9 @@ "app/components/workflow/nodes/knowledge-base/types.ts": { "erasable-syntax-only/enums": { "count": 1 + }, + "no-barrel-files/no-barrel-files": { + "count": 8 } }, "app/components/workflow/nodes/knowledge-base/use-single-run-form-params.ts": { @@ -8124,6 +8722,11 @@ "count": 1 } }, + "app/components/workflow/nodes/llm/components/json-schema-config-modal/json-schema-generator/assets/index.tsx": { + "no-barrel-files/no-barrel-files": { + "count": 2 + } + }, "app/components/workflow/nodes/llm/components/json-schema-config-modal/json-schema-generator/generated-result.tsx": { "style/multiline-ternary": { "count": 2 @@ -8551,6 +9154,9 @@ } }, "app/components/workflow/nodes/tool/types.ts": { + "no-barrel-files/no-barrel-files": { + "count": 1 + }, "ts/no-explicit-any": { "count": 3 } @@ -8590,6 +9196,9 @@ } }, "app/components/workflow/nodes/trigger-plugin/types.ts": { + "no-barrel-files/no-barrel-files": { + "count": 1 + }, "ts/no-explicit-any": { "count": 4 } @@ -8678,6 +9287,11 @@ "count": 5 } }, + 
"app/components/workflow/note-node/note-editor/index.tsx": { + "no-barrel-files/no-barrel-files": { + "count": 3 + } + }, "app/components/workflow/note-node/note-editor/plugins/link-editor-plugin/component.tsx": { "react/set-state-in-effect": { "count": 1 @@ -8719,11 +9333,6 @@ "count": 1 } }, - "app/components/workflow/note-node/types.ts": { - "erasable-syntax-only/enums": { - "count": 1 - } - }, "app/components/workflow/operator/add-block.tsx": { "ts/no-explicit-any": { "count": 1 @@ -9020,6 +9629,11 @@ "count": 1 } }, + "app/components/workflow/run/agent-log/index.tsx": { + "no-barrel-files/no-barrel-files": { + "count": 2 + } + }, "app/components/workflow/run/hooks.ts": { "ts/no-explicit-any": { "count": 1 @@ -9030,6 +9644,11 @@ "count": 2 } }, + "app/components/workflow/run/iteration-log/index.tsx": { + "no-barrel-files/no-barrel-files": { + "count": 2 + } + }, "app/components/workflow/run/iteration-log/iteration-log-trigger.tsx": { "tailwindcss/enforce-consistent-class-order": { "count": 1 @@ -9043,6 +9662,11 @@ "count": 2 } }, + "app/components/workflow/run/loop-log/index.tsx": { + "no-barrel-files/no-barrel-files": { + "count": 2 + } + }, "app/components/workflow/run/loop-log/loop-log-trigger.tsx": { "tailwindcss/enforce-consistent-class-order": { "count": 1 @@ -9101,6 +9725,11 @@ "count": 2 } }, + "app/components/workflow/run/retry-log/index.tsx": { + "no-barrel-files/no-barrel-files": { + "count": 2 + } + }, "app/components/workflow/run/retry-log/retry-result-panel.tsx": { "tailwindcss/enforce-consistent-class-order": { "count": 1 @@ -9154,6 +9783,11 @@ "count": 1 } }, + "app/components/workflow/store/index.ts": { + "no-barrel-files/no-barrel-files": { + "count": 2 + } + }, "app/components/workflow/store/workflow/debug/inspect-vars-slice.ts": { "ts/no-explicit-any": { "count": 2 @@ -9198,6 +9832,11 @@ "count": 1 } }, + "app/components/workflow/utils/index.ts": { + "no-barrel-files/no-barrel-files": { + "count": 10 + } + }, 
"app/components/workflow/utils/node-navigation.ts": { "ts/no-explicit-any": { "count": 2 @@ -9301,11 +9940,6 @@ "count": 1 } }, - "app/components/workflow/variable-inspect/types.ts": { - "erasable-syntax-only/enums": { - "count": 2 - } - }, "app/components/workflow/variable-inspect/utils.tsx": { "ts/no-explicit-any": { "count": 2 @@ -9556,6 +10190,9 @@ } }, "context/modal-context-provider.tsx": { + "no-barrel-files/no-barrel-files": { + "count": 1 + }, "ts/no-explicit-any": { "count": 3 } @@ -9632,6 +10269,16 @@ "count": 4 } }, + "i18n-config/index.ts": { + "no-barrel-files/no-barrel-files": { + "count": 1 + } + }, + "i18n-config/lib.client.ts": { + "no-barrel-files/no-barrel-files": { + "count": 1 + } + }, "i18n/de-DE/billing.json": { "no-irregular-whitespace": { "count": 1 @@ -9809,6 +10456,11 @@ "count": 3 } }, + "service/try-app.ts": { + "no-barrel-files/no-barrel-files": { + "count": 2 + } + }, "service/use-apps.ts": { "ts/no-explicit-any": { "count": 1 @@ -9827,6 +10479,11 @@ "count": 7 } }, + "service/use-flow.ts": { + "react/no-unnecessary-use-prefix": { + "count": 1 + } + }, "service/use-pipeline.ts": { "@tanstack/query/exhaustive-deps": { "count": 2 @@ -9902,11 +10559,6 @@ "count": 3 } }, - "types/model-provider.ts": { - "erasable-syntax-only/enums": { - "count": 1 - } - }, "types/pipeline.tsx": { "ts/no-explicit-any": { "count": 3 @@ -9918,9 +10570,6 @@ } }, "types/workflow.ts": { - "erasable-syntax-only/enums": { - "count": 1 - }, "ts/no-explicit-any": { "count": 17 } diff --git a/web/eslint.config.mjs b/web/eslint.config.mjs index d5d833ba69..940aad9f4c 100644 --- a/web/eslint.config.mjs +++ b/web/eslint.config.mjs @@ -7,6 +7,7 @@ import md from 'eslint-markdown' import tailwindcss from 'eslint-plugin-better-tailwindcss' import hyoban from 'eslint-plugin-hyoban' import markdownPreferences from 'eslint-plugin-markdown-preferences' +import noBarrelFiles from 'eslint-plugin-no-barrel-files' import { reactRefresh } from 'eslint-plugin-react-refresh' import 
sonar from 'eslint-plugin-sonarjs' import storybook from 'eslint-plugin-storybook' @@ -30,12 +31,17 @@ const plugins = pluginReact.configs.all.plugins export default antfu( { react: false, - nextjs: true, + nextjs: { + overrides: { + 'next/no-img-element': 'off', + }, + }, ignores: ['public', 'types/doc-paths.ts', 'eslint-suppressions.json'], typescript: { overrides: { 'ts/consistent-type-definitions': ['error', 'type'], 'ts/no-explicit-any': 'error', + 'ts/no-redeclare': 'off', }, erasableOnly: true, }, @@ -66,12 +72,23 @@ export default antfu( ...pluginReact.configs['recommended-typescript'].rules, 'react/prefer-namespace-import': 'error', 'react/set-state-in-effect': 'error', + 'react/no-unnecessary-use-prefix': 'error', }, }, { files: [...GLOB_TESTS, GLOB_MARKDOWN_CODE, 'vitest.setup.ts', 'test/i18n-mock.ts'], rules: { 'react/component-hook-factories': 'off', + 'react/no-unnecessary-use-prefix': 'off', + }, + }, + { + plugins: { + 'no-barrel-files': noBarrelFiles, + }, + ignores: ['next/**'], + rules: { + 'no-barrel-files/no-barrel-files': 'error', }, }, reactRefresh.configs.next(), @@ -98,7 +115,6 @@ export default antfu( { rules: { 'node/prefer-global/process': 'off', - 'next/no-img-element': 'off', }, }, { @@ -160,7 +176,7 @@ export default antfu( }, }, { - files: ['**/package.json'], + files: ['package.json'], rules: { 'hyoban/no-dependency-version-prefix': 'error', }, diff --git a/web/i18n/ar-TN/app-debug.json b/web/i18n/ar-TN/app-debug.json index c35983397a..99c510dcb5 100644 --- a/web/i18n/ar-TN/app-debug.json +++ b/web/i18n/ar-TN/app-debug.json @@ -235,11 +235,16 @@ "inputs.run": "تشغيل", "inputs.title": "تصحيح ومعاينة", "inputs.userInputField": "حقل إدخال المستخدم", + "manageModels": "إدارة النماذج", "modelConfig.modeType.chat": "دردشة", "modelConfig.modeType.completion": "إكمال", "modelConfig.model": "نموذج", "modelConfig.setTone": "تعيين نبرة الاستجابات", "modelConfig.title": "النموذج والمعلمات", + "noModelProviderConfigured": "لم يتم تكوين أي مزوّد 
نماذج", + "noModelProviderConfiguredTip": "قم بتثبيت أو تكوين مزوّد نماذج للبدء.", + "noModelSelected": "لم يتم اختيار أي نموذج", + "noModelSelectedTip": "قم بتكوين نموذج أعلاه للمتابعة.", "noResult": "سيتم عرض الإخراج هنا.", "notSetAPIKey.description": "لم يتم تعيين مفتاح مزود LLM، ويجب تعيينه قبل تصحيح الأخطاء.", "notSetAPIKey.settingBtn": "الذهاب إلى الإعدادات", diff --git a/web/i18n/ar-TN/common.json b/web/i18n/ar-TN/common.json index 74fe80bcff..3bc7c05564 100644 --- a/web/i18n/ar-TN/common.json +++ b/web/i18n/ar-TN/common.json @@ -340,11 +340,25 @@ "modelProvider.auth.unAuthorized": "غير مصرح به", "modelProvider.buyQuota": "شراء حصة", "modelProvider.callTimes": "أوقات الاتصال", + "modelProvider.card.aiCreditsInUse": "أرصدة الذكاء الاصطناعي قيد الاستخدام", + "modelProvider.card.aiCreditsOption": "أرصدة الذكاء الاصطناعي", + "modelProvider.card.apiKeyOption": "مفتاح API", + "modelProvider.card.apiKeyRequired": "مفتاح API مطلوب", + "modelProvider.card.apiKeyUnavailableFallback": "مفتاح API غير متاح، يتم الآن استخدام أرصدة الذكاء الاصطناعي", + "modelProvider.card.apiKeyUnavailableFallbackDescription": "تحقق من تكوين مفتاح API الخاص بك للتبديل مرة أخرى", "modelProvider.card.buyQuota": "شراء حصة", "modelProvider.card.callTimes": "أوقات الاتصال", + "modelProvider.card.creditsExhaustedDescription": "يرجى ترقية خطتك أو تكوين مفتاح API", + "modelProvider.card.creditsExhaustedFallback": "نفدت أرصدة الذكاء الاصطناعي، يتم الآن استخدام مفتاح API", + "modelProvider.card.creditsExhaustedFallbackDescription": "قم بترقية خطتك لاستئناف أولوية أرصدة الذكاء الاصطناعي.", + "modelProvider.card.creditsExhaustedMessage": "نفدت أرصدة الذكاء الاصطناعي", "modelProvider.card.modelAPI": "نماذج {{modelName}} تستخدم مفتاح API.", "modelProvider.card.modelNotSupported": "نماذج {{modelName}} غير مثبتة.", "modelProvider.card.modelSupported": "نماذج {{modelName}} تستخدم هذا الحصة.", + "modelProvider.card.noApiKeysDescription": "أضف مفتاح API لبدء استخدام بيانات اعتماد النموذج الخاصة بك.", + 
"modelProvider.card.noApiKeysFallback": "لا توجد مفاتيح API، يتم استخدام أرصدة الذكاء الاصطناعي بدلاً من ذلك", + "modelProvider.card.noApiKeysTitle": "لم يتم تكوين أي مفاتيح API بعد", + "modelProvider.card.noAvailableUsage": "لا يوجد استخدام متاح", "modelProvider.card.onTrial": "في التجربة", "modelProvider.card.paid": "مدفوع", "modelProvider.card.priorityUse": "أولوية الاستخدام", @@ -353,6 +367,11 @@ "modelProvider.card.removeKey": "إزالة مفتاح API", "modelProvider.card.tip": "تدعم أرصدة الرسائل نماذج من {{modelNames}}. ستعطى الأولوية للحصة المدفوعة. سيتم استخدام الحصة المجانية بعد نفاد الحصة المدفوعة.", "modelProvider.card.tokens": "رموز", + "modelProvider.card.unavailable": "غير متاح", + "modelProvider.card.upgradePlan": "ترقية خطتك", + "modelProvider.card.usageLabel": "الاستخدام", + "modelProvider.card.usagePriority": "أولوية الاستخدام", + "modelProvider.card.usagePriorityTip": "تعيين المورد الذي يجب استخدامه أولاً عند تشغيل النماذج.", "modelProvider.collapse": "طي", "modelProvider.config": "تكوين", "modelProvider.configLoadBalancing": "تكوين موازنة التحميل", @@ -387,9 +406,11 @@ "modelProvider.model": "النموذج", "modelProvider.modelAndParameters": "النموذج والمعلمات", "modelProvider.modelHasBeenDeprecated": "تم إهمال هذا النموذج", + "modelProvider.modelSettings": "إعدادات النموذج", "modelProvider.models": "النماذج", "modelProvider.modelsNum": "{{num}} نماذج", "modelProvider.noModelFound": "لم يتم العثور على نموذج لـ {{model}}", + "modelProvider.noneConfigured": "قم بتكوين نموذج نظام افتراضي لتشغيل التطبيقات", "modelProvider.notConfigured": "لم يتم تكوين نموذج النظام بالكامل بعد", "modelProvider.parameters": "المعلمات", "modelProvider.parametersInvalidRemoved": "بعض المعلمات غير صالحة وتمت إزالتها", @@ -403,8 +424,25 @@ "modelProvider.resetDate": "إعادة التعيين في {{date}}", "modelProvider.searchModel": "نموذج البحث", "modelProvider.selectModel": "اختر نموذجك", + "modelProvider.selector.aiCredits": "أرصدة الذكاء الاصطناعي", + 
"modelProvider.selector.apiKeyUnavailable": "مفتاح API غير متاح", + "modelProvider.selector.apiKeyUnavailableTip": "تمت إزالة مفتاح API. يرجى تكوين مفتاح API جديد.", + "modelProvider.selector.configure": "تكوين", + "modelProvider.selector.configureRequired": "التكوين مطلوب", + "modelProvider.selector.creditsExhausted": "نفدت الأرصدة", + "modelProvider.selector.creditsExhaustedTip": "نفدت أرصدة الذكاء الاصطناعي الخاصة بك. يرجى ترقية خطتك أو إضافة مفتاح API.", + "modelProvider.selector.disabled": "معطل", + "modelProvider.selector.discoverMoreInMarketplace": "اكتشف المزيد في السوق", "modelProvider.selector.emptySetting": "يرجى الانتقال إلى الإعدادات للتكوين", "modelProvider.selector.emptyTip": "لا توجد نماذج متاحة", + "modelProvider.selector.fromMarketplace": "من السوق", + "modelProvider.selector.incompatible": "غير متوافق", + "modelProvider.selector.incompatibleTip": "هذا النموذج غير متاح في الإصدار الحالي. يرجى تحديد نموذج متاح آخر.", + "modelProvider.selector.install": "تثبيت", + "modelProvider.selector.modelProviderSettings": "إعدادات مزود النموذج", + "modelProvider.selector.noProviderConfigured": "لم يتم تكوين أي مزود نموذج", + "modelProvider.selector.noProviderConfiguredDesc": "تصفح السوق لتثبيت مزود، أو قم بتكوين المزودين في الإعدادات.", + "modelProvider.selector.onlyCompatibleModelsShown": "يتم عرض النماذج المتوافقة فقط", "modelProvider.selector.rerankTip": "يرجى إعداد نموذج إعادة الترتيب", "modelProvider.selector.tip": "تمت إزالة هذا النموذج. 
يرجى إضافة نموذج أو تحديد نموذج آخر.", "modelProvider.setupModelFirst": "يرجى إعداد نموذجك أولاً", diff --git a/web/i18n/ar-TN/plugin.json b/web/i18n/ar-TN/plugin.json index 39052f4295..c48a15fb24 100644 --- a/web/i18n/ar-TN/plugin.json +++ b/web/i18n/ar-TN/plugin.json @@ -3,6 +3,7 @@ "action.delete": "إزالة الإضافة", "action.deleteContentLeft": "هل ترغب في إزالة ", "action.deleteContentRight": " الإضافة؟", + "action.deleteSuccess": "تم حذف الإضافة بنجاح", "action.pluginInfo": "معلومات الإضافة", "action.usedInApps": "يتم استخدام هذه الإضافة في {{num}} تطبيقات.", "allCategories": "جميع الفئات", @@ -114,6 +115,7 @@ "detailPanel.operation.install": "تثبيت", "detailPanel.operation.remove": "إزالة", "detailPanel.operation.update": "تحديث", + "detailPanel.operation.updateTooltip": "قم بالتحديث للوصول إلى أحدث النماذج.", "detailPanel.operation.viewDetail": "عرض التفاصيل", "detailPanel.serviceOk": "الخدمة جيدة", "detailPanel.strategyNum": "{{num}} {{strategy}} متضمن", @@ -231,12 +233,18 @@ "source.local": "ملف الحزمة المحلية", "source.marketplace": "السوق", "task.clearAll": "مسح الكل", + "task.errorMsg.github": "لم يتم تثبيت هذه الإضافة تلقائيًا.\nيرجى تثبيتها من GitHub.", + "task.errorMsg.marketplace": "لم يتم تثبيت هذه الإضافة تلقائيًا.\nيرجى تثبيتها من السوق.", + "task.errorMsg.unknown": "لم يتم تثبيت هذه الإضافة.\nتعذر التعرف على مصدر الإضافة.", "task.errorPlugins": "فشل في تثبيت الإضافات", "task.installError": "{{errorLength}} إضافات فشل تثبيتها، انقر للعرض", + "task.installFromGithub": "التثبيت من GitHub", + "task.installFromMarketplace": "التثبيت من السوق", "task.installSuccess": "تم تثبيت {{successLength}} من الإضافات بنجاح", "task.installed": "مثبت", "task.installedError": "{{errorLength}} إضافات فشل تثبيتها", "task.installing": "جارٍ تثبيت الإضافات.", + "task.installingHint": "جارٍ التثبيت... 
قد يستغرق هذا بضع دقائق.", "task.installingWithError": "تثبيت {{installingLength}} إضافات، {{successLength}} نجاح، {{errorLength}} فشل", "task.installingWithSuccess": "تثبيت {{installingLength}} إضافات، {{successLength}} نجاح.", "task.runningPlugins": "تثبيت الإضافات", diff --git a/web/i18n/ar-TN/workflow.json b/web/i18n/ar-TN/workflow.json index 1bb5237ff7..2487538071 100644 --- a/web/i18n/ar-TN/workflow.json +++ b/web/i18n/ar-TN/workflow.json @@ -305,6 +305,7 @@ "error.operations.updatingWorkflow": "تحديث سير العمل", "error.startNodeRequired": "الرجاء إضافة عقدة البداية أولاً قبل {{operation}}", "errorMsg.authRequired": "الترخيص مطلوب", + "errorMsg.configureModel": "قم بتكوين نموذج", "errorMsg.fieldRequired": "{{field}} مطلوب", "errorMsg.fields.code": "الكود", "errorMsg.fields.model": "النموذج", @@ -314,6 +315,7 @@ "errorMsg.fields.visionVariable": "متغير الرؤية", "errorMsg.invalidJson": "{{field}} هو JSON غير صالح", "errorMsg.invalidVariable": "متغير غير صالح", + "errorMsg.modelPluginNotInstalled": "متغير غير صالح. قم بتكوين نموذج لتمكين هذا المتغير.", "errorMsg.noValidTool": "{{field}} لا توجد أداة صالحة محددة", "errorMsg.rerankModelRequired": "مطلوب تكوين نموذج Rerank", "errorMsg.startNodeRequired": "الرجاء إضافة عقدة البداية أولاً قبل {{operation}}", @@ -444,6 +446,7 @@ "nodes.common.memory.windowSize": "حجم النافذة", "nodes.common.outputVars": "متغيرات الإخراج", "nodes.common.pluginNotInstalled": "الإضافة غير مثبتة", + "nodes.common.pluginsNotInstalled": "{{count}} إضافات غير مثبتة", "nodes.common.retry.maxRetries": "الحد الأقصى لإعادة المحاولة", "nodes.common.retry.ms": "مللي ثانية", "nodes.common.retry.retries": "{{num}} إعادة محاولة", @@ -686,9 +689,14 @@ "nodes.knowledgeBase.chunksInput": "القطع", "nodes.knowledgeBase.chunksInputTip": "متغير الإدخال لعقدة قاعدة المعرفة هو Pieces. 
نوع المتغير هو كائن بمخطط JSON محدد يجب أن يكون متسقًا مع هيكل القطعة المحدد.", "nodes.knowledgeBase.chunksVariableIsRequired": "متغير القطع مطلوب", + "nodes.knowledgeBase.embeddingModelApiKeyUnavailable": "مفتاح API غير متاح", + "nodes.knowledgeBase.embeddingModelCreditsExhausted": "نفدت الأرصدة", + "nodes.knowledgeBase.embeddingModelIncompatible": "غير متوافق", "nodes.knowledgeBase.embeddingModelIsInvalid": "نموذج التضمين غير صالح", "nodes.knowledgeBase.embeddingModelIsRequired": "نموذج التضمين مطلوب", + "nodes.knowledgeBase.embeddingModelNotConfigured": "نموذج التضمين غير مكوّن", "nodes.knowledgeBase.indexMethodIsRequired": "طريقة الفهرسة مطلوبة", + "nodes.knowledgeBase.notConfigured": "غير مكوّن", "nodes.knowledgeBase.rerankingModelIsInvalid": "نموذج إعادة الترتيب غير صالح", "nodes.knowledgeBase.rerankingModelIsRequired": "نموذج إعادة الترتيب مطلوب", "nodes.knowledgeBase.retrievalSettingIsRequired": "إعداد الاسترجاع مطلوب", @@ -872,6 +880,7 @@ "nodes.templateTransform.codeSupportTip": "يدعم Jinja2 فقط", "nodes.templateTransform.inputVars": "متغيرات الإدخال", "nodes.templateTransform.outputVars.output": "المحتوى المحول", + "nodes.tool.authorizationRequired": "التفويض مطلوب", "nodes.tool.authorize": "تخويل", "nodes.tool.inputVars": "متغيرات الإدخال", "nodes.tool.insertPlaceholder1": "اكتب أو اضغط", @@ -1062,10 +1071,12 @@ "panel.change": "تغيير", "panel.changeBlock": "تغيير العقدة", "panel.checklist": "قائمة المراجعة", + "panel.checklistDescription": "حل المشكلات التالية قبل النشر", "panel.checklistResolved": "تم حل جميع المشكلات", "panel.checklistTip": "تأكد من حل جميع المشكلات قبل النشر", "panel.createdBy": "تم الإنشاء بواسطة ", "panel.goTo": "الذهاب إلى", + "panel.goToFix": "الذهاب إلى الإصلاح", "panel.helpLink": "عرض المستندات", "panel.maximize": "تكبير القماش", "panel.minimize": "خروج من وضع ملء الشاشة", diff --git a/web/i18n/de-DE/app-debug.json b/web/i18n/de-DE/app-debug.json index f75982bee9..0e067e20eb 100644 --- a/web/i18n/de-DE/app-debug.json +++ 
b/web/i18n/de-DE/app-debug.json @@ -235,11 +235,16 @@ "inputs.run": "AUSFÜHREN", "inputs.title": "Debug und Vorschau", "inputs.userInputField": "Benutzereingabefeld", + "manageModels": "Modelle verwalten", "modelConfig.modeType.chat": "Chat", "modelConfig.modeType.completion": "Vollständig", "modelConfig.model": "Modell", "modelConfig.setTone": "Ton der Antworten festlegen", "modelConfig.title": "Modell und Parameter", + "noModelProviderConfigured": "Kein Modellanbieter konfiguriert", + "noModelProviderConfiguredTip": "Installieren oder konfigurieren Sie einen Modellanbieter, um zu beginnen.", + "noModelSelected": "Kein Modell ausgewählt", + "noModelSelectedTip": "Konfigurieren Sie oben ein Modell, um fortzufahren.", "noResult": "Hier wird die Ausgabe angezeigt.", "notSetAPIKey.description": "Der LLM-Anbieterschlüssel wurde nicht festgelegt und muss vor dem Debuggen festgelegt werden.", "notSetAPIKey.settingBtn": "Zu den Einstellungen gehen", diff --git a/web/i18n/de-DE/common.json b/web/i18n/de-DE/common.json index a7aaf55086..8639a24f3e 100644 --- a/web/i18n/de-DE/common.json +++ b/web/i18n/de-DE/common.json @@ -340,11 +340,25 @@ "modelProvider.auth.unAuthorized": "Unbefugt", "modelProvider.buyQuota": "Kontingent kaufen", "modelProvider.callTimes": "Anrufzeiten", + "modelProvider.card.aiCreditsInUse": "AI Credits werden verwendet", + "modelProvider.card.aiCreditsOption": "AI Credits", + "modelProvider.card.apiKeyOption": "API Key", + "modelProvider.card.apiKeyRequired": "API Key erforderlich", + "modelProvider.card.apiKeyUnavailableFallback": "API Key nicht verfügbar, AI Credits werden verwendet", + "modelProvider.card.apiKeyUnavailableFallbackDescription": "Überprüfen Sie Ihre API-Key-Konfiguration, um zurückzuwechseln", "modelProvider.card.buyQuota": "Kontingent kaufen", "modelProvider.card.callTimes": "Anrufzeiten", + "modelProvider.card.creditsExhaustedDescription": "Bitte upgraden Sie Ihren Plan oder konfigurieren Sie einen API Key", + 
"modelProvider.card.creditsExhaustedFallback": "AI Credits aufgebraucht, API Key wird verwendet", + "modelProvider.card.creditsExhaustedFallbackDescription": "Upgraden Sie Ihren Plan, um die AI-Credit-Priorität wiederherzustellen.", + "modelProvider.card.creditsExhaustedMessage": "AI Credits wurden aufgebraucht", "modelProvider.card.modelAPI": "{{modelName}}-Modelle verwenden den API-Schlüssel.", "modelProvider.card.modelNotSupported": "{{modelName}}-Modelle sind nicht installiert.", "modelProvider.card.modelSupported": "{{modelName}}-Modelle verwenden dieses Kontingent.", + "modelProvider.card.noApiKeysDescription": "Fügen Sie einen API Key hinzu, um Ihre eigenen Modell-Zugangsdaten zu verwenden.", + "modelProvider.card.noApiKeysFallback": "Keine API Keys, AI Credits werden verwendet", + "modelProvider.card.noApiKeysTitle": "Noch keine API Keys konfiguriert", + "modelProvider.card.noAvailableUsage": "Kein verfügbares Guthaben", "modelProvider.card.onTrial": "In Probe", "modelProvider.card.paid": "Bezahlt", "modelProvider.card.priorityUse": "Priorisierte Nutzung", @@ -353,6 +367,11 @@ "modelProvider.card.removeKey": "API-Schlüssel entfernen", "modelProvider.card.tip": "Nachrichtenguthaben unterstützen Modelle von {{modelNames}}. Der bezahlten Kontingent wird Vorrang gegeben. 
Das kostenlose Kontingent wird nach dem Verbrauch des bezahlten Kontingents verwendet.", "modelProvider.card.tokens": "Token", + "modelProvider.card.unavailable": "Nicht verfügbar", + "modelProvider.card.upgradePlan": "Plan upgraden", + "modelProvider.card.usageLabel": "Verbrauch", + "modelProvider.card.usagePriority": "Nutzungspriorität", + "modelProvider.card.usagePriorityTip": "Legen Sie fest, welche Ressource beim Ausführen von Modellen zuerst verwendet wird.", "modelProvider.collapse": "Einklappen", "modelProvider.config": "Konfigurieren", "modelProvider.configLoadBalancing": "Lastenausgleich für die Konfiguration", @@ -387,9 +406,11 @@ "modelProvider.model": "Modell", "modelProvider.modelAndParameters": "Modell und Parameter", "modelProvider.modelHasBeenDeprecated": "Dieses Modell ist veraltet", + "modelProvider.modelSettings": "Modelleinstellungen", "modelProvider.models": "Modelle", "modelProvider.modelsNum": "{{num}} Modelle", "modelProvider.noModelFound": "Kein Modell für {{model}} gefunden", + "modelProvider.noneConfigured": "Konfigurieren Sie ein Standard-Systemmodell, um Anwendungen auszuführen", "modelProvider.notConfigured": "Das Systemmodell wurde noch nicht vollständig konfiguriert, und einige Funktionen sind möglicherweise nicht verfügbar.", "modelProvider.parameters": "PARAMETER", "modelProvider.parametersInvalidRemoved": "Einige Parameter sind ungültig und wurden entfernt.", @@ -403,8 +424,25 @@ "modelProvider.resetDate": "Zurücksetzen am {{date}}", "modelProvider.searchModel": "Suchmodell", "modelProvider.selectModel": "Wählen Sie Ihr Modell", + "modelProvider.selector.aiCredits": "AI Credits", + "modelProvider.selector.apiKeyUnavailable": "API Key nicht verfügbar", + "modelProvider.selector.apiKeyUnavailableTip": "Der API Key wurde entfernt. 
Bitte konfigurieren Sie einen neuen API Key.", + "modelProvider.selector.configure": "Konfigurieren", + "modelProvider.selector.configureRequired": "Konfiguration erforderlich", + "modelProvider.selector.creditsExhausted": "Guthaben aufgebraucht", + "modelProvider.selector.creditsExhaustedTip": "Ihre AI Credits wurden aufgebraucht. Bitte upgraden Sie Ihren Plan oder fügen Sie einen API Key hinzu.", + "modelProvider.selector.disabled": "Deaktiviert", + "modelProvider.selector.discoverMoreInMarketplace": "Mehr im Marketplace entdecken", "modelProvider.selector.emptySetting": "Bitte gehen Sie zu den Einstellungen, um zu konfigurieren", "modelProvider.selector.emptyTip": "Keine verfügbaren Modelle", + "modelProvider.selector.fromMarketplace": "Vom Marketplace", + "modelProvider.selector.incompatible": "Inkompatibel", + "modelProvider.selector.incompatibleTip": "Dieses Modell ist in der aktuellen Version nicht verfügbar. Bitte wählen Sie ein anderes verfügbares Modell.", + "modelProvider.selector.install": "Installieren", + "modelProvider.selector.modelProviderSettings": "Modellanbieter-Einstellungen", + "modelProvider.selector.noProviderConfigured": "Kein Modellanbieter konfiguriert", + "modelProvider.selector.noProviderConfiguredDesc": "Durchsuchen Sie den Marketplace, um einen zu installieren, oder konfigurieren Sie Anbieter in den Einstellungen.", + "modelProvider.selector.onlyCompatibleModelsShown": "Es werden nur kompatible Modelle angezeigt", "modelProvider.selector.rerankTip": "Bitte richten Sie das Rerank-Modell ein", "modelProvider.selector.tip": "Dieses Modell wurde entfernt. 
Bitte fügen Sie ein Modell hinzu oder wählen Sie ein anderes Modell.", "modelProvider.setupModelFirst": "Bitte richten Sie zuerst Ihr Modell ein", diff --git a/web/i18n/de-DE/plugin.json b/web/i18n/de-DE/plugin.json index c0158be29b..e805f3fabd 100644 --- a/web/i18n/de-DE/plugin.json +++ b/web/i18n/de-DE/plugin.json @@ -3,6 +3,7 @@ "action.delete": "Plugin entfernen", "action.deleteContentLeft": "Möchten Sie", "action.deleteContentRight": "Plugin?", + "action.deleteSuccess": "Plugin erfolgreich entfernt", "action.pluginInfo": "Plugin-Info", "action.usedInApps": "Dieses Plugin wird in {{num}} Apps verwendet.", "allCategories": "Alle Kategorien", @@ -114,6 +115,7 @@ "detailPanel.operation.install": "Installieren", "detailPanel.operation.remove": "Entfernen", "detailPanel.operation.update": "Aktualisieren", + "detailPanel.operation.updateTooltip": "Aktualisieren Sie, um auf die neuesten Modelle zuzugreifen.", "detailPanel.operation.viewDetail": "Im Detail sehen", "detailPanel.serviceOk": "Service in Ordnung", "detailPanel.strategyNum": "{{num}} {{strategy}} IINKLUSIVE", @@ -231,12 +233,18 @@ "source.local": "Lokale Paketdatei", "source.marketplace": "Marktplatz", "task.clearAll": "Alle löschen", + "task.errorMsg.github": "Dieses Plugin konnte nicht automatisch installiert werden.\nBitte installieren Sie es von GitHub.", + "task.errorMsg.marketplace": "Dieses Plugin konnte nicht automatisch installiert werden.\nBitte installieren Sie es vom Marketplace.", + "task.errorMsg.unknown": "Dieses Plugin konnte nicht installiert werden.\nDie Plugin-Quelle konnte nicht identifiziert werden.", "task.errorPlugins": "Failed to Install Plugins", "task.installError": "{{errorLength}} Plugins konnten nicht installiert werden, klicken Sie hier, um sie anzusehen", + "task.installFromGithub": "Von GitHub installieren", + "task.installFromMarketplace": "Vom Marketplace installieren", "task.installSuccess": "{{successLength}} plugins installed successfully", "task.installed": "Installed", 
"task.installedError": "{{errorLength}} Plugins konnten nicht installiert werden", "task.installing": "Plugins werden installiert.", + "task.installingHint": "Installation läuft... Dies kann einige Minuten dauern.", "task.installingWithError": "Installation von {{installingLength}} Plugins, {{successLength}} erfolgreich, {{errorLength}} fehlgeschlagen", "task.installingWithSuccess": "Installation von {{installingLength}} Plugins, {{successLength}} erfolgreich.", "task.runningPlugins": "Installing Plugins", diff --git a/web/i18n/de-DE/workflow.json b/web/i18n/de-DE/workflow.json index 48037ce006..362eac19b6 100644 --- a/web/i18n/de-DE/workflow.json +++ b/web/i18n/de-DE/workflow.json @@ -305,6 +305,7 @@ "error.operations.updatingWorkflow": "Arbeitsablauf aktualisieren", "error.startNodeRequired": "Bitte füge zuerst einen Startknoten hinzu, bevor du {{operation}}.", "errorMsg.authRequired": "Autorisierung ist erforderlich", + "errorMsg.configureModel": "Modell konfigurieren", "errorMsg.fieldRequired": "{{field}} ist erforderlich", "errorMsg.fields.code": "Code", "errorMsg.fields.model": "Modell", @@ -314,6 +315,7 @@ "errorMsg.fields.visionVariable": "Vision variabel", "errorMsg.invalidJson": "{{field}} ist ein ungültiges JSON", "errorMsg.invalidVariable": "Ungültige Variable", + "errorMsg.modelPluginNotInstalled": "Ungültige Variable. 
Konfigurieren Sie ein Modell, um diese Variable zu aktivieren.", "errorMsg.noValidTool": "{{field}} kein gültiges Werkzeug ausgewählt", "errorMsg.rerankModelRequired": "Bevor Sie das Rerank-Modell aktivieren, bestätigen Sie bitte, dass das Modell in den Einstellungen erfolgreich konfiguriert wurde.", "errorMsg.startNodeRequired": "Bitte füge zuerst einen Startknoten hinzu, bevor du {{operation}}.", @@ -444,6 +446,7 @@ "nodes.common.memory.windowSize": "Fenstergröße", "nodes.common.outputVars": "Ausgabevariablen", "nodes.common.pluginNotInstalled": "Plugin ist nicht installiert", + "nodes.common.pluginsNotInstalled": "{{count}} Plugins nicht installiert", "nodes.common.retry.maxRetries": "Max. Wiederholungen", "nodes.common.retry.ms": "Frau", "nodes.common.retry.retries": "{{num}} Wiederholungen", @@ -686,9 +689,14 @@ "nodes.knowledgeBase.chunksInput": "Stücke", "nodes.knowledgeBase.chunksInputTip": "Die Eingangsvariable des Wissensbasis-Knotens sind Chunks. Der Variablentyp ist ein Objekt mit einem spezifischen JSON-Schema, das konsistent mit der ausgewählten Chunk-Struktur sein muss.", "nodes.knowledgeBase.chunksVariableIsRequired": "Die Variable 'Chunks' ist erforderlich", + "nodes.knowledgeBase.embeddingModelApiKeyUnavailable": "API Key nicht verfügbar", + "nodes.knowledgeBase.embeddingModelCreditsExhausted": "Guthaben aufgebraucht", + "nodes.knowledgeBase.embeddingModelIncompatible": "Inkompatibel", "nodes.knowledgeBase.embeddingModelIsInvalid": "Einbettungsmodell ist ungültig", "nodes.knowledgeBase.embeddingModelIsRequired": "Ein Einbettungsmodell ist erforderlich", + "nodes.knowledgeBase.embeddingModelNotConfigured": "Embedding-Modell nicht konfiguriert", "nodes.knowledgeBase.indexMethodIsRequired": "Index-Methode ist erforderlich", + "nodes.knowledgeBase.notConfigured": "Nicht konfiguriert", "nodes.knowledgeBase.rerankingModelIsInvalid": "Das Reranking-Modell ist ungültig", "nodes.knowledgeBase.rerankingModelIsRequired": "Ein Reranking-Modell ist 
erforderlich", "nodes.knowledgeBase.retrievalSettingIsRequired": "Abrufeinstellung ist erforderlich", @@ -872,6 +880,7 @@ "nodes.templateTransform.codeSupportTip": "Unterstützt nur Jinja2", "nodes.templateTransform.inputVars": "Eingabevariablen", "nodes.templateTransform.outputVars.output": "Transformierter Inhalt", + "nodes.tool.authorizationRequired": "Autorisierung erforderlich", "nodes.tool.authorize": "Autorisieren", "nodes.tool.inputVars": "Eingabevariablen", "nodes.tool.insertPlaceholder1": "Tippen oder drücken", @@ -1062,10 +1071,12 @@ "panel.change": "Ändern", "panel.changeBlock": "Knoten ändern", "panel.checklist": "Checkliste", + "panel.checklistDescription": "Beheben Sie die folgenden Probleme vor der Veröffentlichung", "panel.checklistResolved": "Alle Probleme wurden gelöst", "panel.checklistTip": "Stellen Sie sicher, dass alle Probleme vor der Veröffentlichung gelöst sind", "panel.createdBy": "Erstellt von ", "panel.goTo": "Gehe zu", + "panel.goToFix": "Zur Behebung", "panel.helpLink": "Hilfe", "panel.maximize": "Maximiere die Leinwand", "panel.minimize": "Vollbildmodus beenden", diff --git a/web/i18n/es-ES/app-debug.json b/web/i18n/es-ES/app-debug.json index 8245f2b325..cfb8a0643e 100644 --- a/web/i18n/es-ES/app-debug.json +++ b/web/i18n/es-ES/app-debug.json @@ -235,11 +235,16 @@ "inputs.run": "EJECUTAR", "inputs.title": "Depurar y Previsualizar", "inputs.userInputField": "Campo de Entrada del Usuario", + "manageModels": "Gestionar modelos", "modelConfig.modeType.chat": "Chat", "modelConfig.modeType.completion": "Completar", "modelConfig.model": "Modelo", "modelConfig.setTone": "Establecer tono de respuestas", "modelConfig.title": "Modelo y Parámetros", + "noModelProviderConfigured": "Ningún proveedor de modelos configurado", + "noModelProviderConfiguredTip": "Instala o configura un proveedor de modelos para comenzar.", + "noModelSelected": "Ningún modelo seleccionado", + "noModelSelectedTip": "Configura un modelo arriba para continuar.", "noResult": 
"La salida se mostrará aquí.", "notSetAPIKey.description": "La clave del proveedor LLM no se ha establecido, y debe configurarse antes de depurar.", "notSetAPIKey.settingBtn": "Ir a configuración", diff --git a/web/i18n/es-ES/common.json b/web/i18n/es-ES/common.json index 595124496b..1b97ce680d 100644 --- a/web/i18n/es-ES/common.json +++ b/web/i18n/es-ES/common.json @@ -340,11 +340,25 @@ "modelProvider.auth.unAuthorized": "No autorizado", "modelProvider.buyQuota": "Comprar Cuota", "modelProvider.callTimes": "Tiempos de llamada", + "modelProvider.card.aiCreditsInUse": "AI credits en uso", + "modelProvider.card.aiCreditsOption": "AI credits", + "modelProvider.card.apiKeyOption": "API Key", + "modelProvider.card.apiKeyRequired": "Se requiere API Key", + "modelProvider.card.apiKeyUnavailableFallback": "API Key no disponible, usando AI credits", + "modelProvider.card.apiKeyUnavailableFallbackDescription": "Comprueba la configuración de tu API Key para volver a usarla", "modelProvider.card.buyQuota": "Comprar Cuota", "modelProvider.card.callTimes": "Tiempos de llamada", + "modelProvider.card.creditsExhaustedDescription": "Por favor, mejora tu plan o configura una API Key", + "modelProvider.card.creditsExhaustedFallback": "AI credits agotados, usando API Key", + "modelProvider.card.creditsExhaustedFallbackDescription": "Mejora tu plan para restablecer la prioridad de AI credits.", + "modelProvider.card.creditsExhaustedMessage": "Los AI credits se han agotado", "modelProvider.card.modelAPI": "Los modelos {{modelName}} están usando la clave API.", "modelProvider.card.modelNotSupported": "Los modelos {{modelName}} no están instalados.", "modelProvider.card.modelSupported": "Los modelos {{modelName}} están usando esta cuota.", + "modelProvider.card.noApiKeysDescription": "Añade una API Key para usar tus propias credenciales de modelo.", + "modelProvider.card.noApiKeysFallback": "Sin API Keys, usando AI credits", + "modelProvider.card.noApiKeysTitle": "No hay API Keys 
configuradas", + "modelProvider.card.noAvailableUsage": "Sin uso disponible", "modelProvider.card.onTrial": "En prueba", "modelProvider.card.paid": "Pagado", "modelProvider.card.priorityUse": "Uso prioritario", @@ -353,6 +367,11 @@ "modelProvider.card.removeKey": "Eliminar CLAVE API", "modelProvider.card.tip": "Créditos de mensajes admite modelos de {{modelNames}}. Se dará prioridad a la cuota pagada. La cuota gratuita se utilizará después de que se agote la cuota pagada.", "modelProvider.card.tokens": "Tokens", + "modelProvider.card.unavailable": "No disponible", + "modelProvider.card.upgradePlan": "mejora tu plan", + "modelProvider.card.usageLabel": "Uso", + "modelProvider.card.usagePriority": "Prioridad de uso", + "modelProvider.card.usagePriorityTip": "Establece qué recurso usar primero al ejecutar modelos.", "modelProvider.collapse": "Colapsar", "modelProvider.config": "Configurar", "modelProvider.configLoadBalancing": "Configurar Balanceo de Carga", @@ -387,9 +406,11 @@ "modelProvider.model": "Modelo", "modelProvider.modelAndParameters": "Modelo y Parámetros", "modelProvider.modelHasBeenDeprecated": "Este modelo ha sido desaprobado", + "modelProvider.modelSettings": "Configuración de modelos", "modelProvider.models": "Modelos", "modelProvider.modelsNum": "{{num}} Modelos", "modelProvider.noModelFound": "No se encontró modelo para {{model}}", + "modelProvider.noneConfigured": "Configura un modelo de sistema predeterminado para ejecutar aplicaciones", "modelProvider.notConfigured": "El modelo del sistema aún no ha sido completamente configurado, y algunas funciones pueden no estar disponibles.", "modelProvider.parameters": "PARÁMETROS", "modelProvider.parametersInvalidRemoved": "Algunos parámetros son inválidos y han sido eliminados", @@ -403,8 +424,25 @@ "modelProvider.resetDate": "Restablecer el {{date}}", "modelProvider.searchModel": "Modelo de búsqueda", "modelProvider.selectModel": "Selecciona tu modelo", + "modelProvider.selector.aiCredits": "AI credits", 
+ "modelProvider.selector.apiKeyUnavailable": "API Key no disponible", + "modelProvider.selector.apiKeyUnavailableTip": "La API Key ha sido eliminada. Por favor, configura una nueva API Key.", + "modelProvider.selector.configure": "Configurar", + "modelProvider.selector.configureRequired": "Configuración requerida", + "modelProvider.selector.creditsExhausted": "Créditos agotados", + "modelProvider.selector.creditsExhaustedTip": "Tus AI credits se han agotado. Por favor, mejora tu plan o añade una API Key.", + "modelProvider.selector.disabled": "Desactivado", + "modelProvider.selector.discoverMoreInMarketplace": "Descubre más en el Marketplace", "modelProvider.selector.emptySetting": "Por favor ve a configuraciones para configurar", "modelProvider.selector.emptyTip": "No hay modelos disponibles", + "modelProvider.selector.fromMarketplace": "Desde el Marketplace", + "modelProvider.selector.incompatible": "Incompatible", + "modelProvider.selector.incompatibleTip": "Este modelo no está disponible en la versión actual. Por favor, selecciona otro modelo disponible.", + "modelProvider.selector.install": "Instalar", + "modelProvider.selector.modelProviderSettings": "Configuración del proveedor de modelos", + "modelProvider.selector.noProviderConfigured": "Ningún proveedor de modelos configurado", + "modelProvider.selector.noProviderConfiguredDesc": "Explora el Marketplace para instalar uno, o configura proveedores en los ajustes.", + "modelProvider.selector.onlyCompatibleModelsShown": "Solo se muestran los modelos compatibles", "modelProvider.selector.rerankTip": "Por favor configura el modelo de Reordenar", "modelProvider.selector.tip": "Este modelo ha sido eliminado. 
Por favor agrega un modelo o selecciona otro modelo.", "modelProvider.setupModelFirst": "Por favor configura tu modelo primero", diff --git a/web/i18n/es-ES/plugin.json b/web/i18n/es-ES/plugin.json index 297b86606d..c7d5855f24 100644 --- a/web/i18n/es-ES/plugin.json +++ b/web/i18n/es-ES/plugin.json @@ -3,6 +3,7 @@ "action.delete": "Eliminar plugin", "action.deleteContentLeft": "¿Le gustaría eliminar", "action.deleteContentRight": "¿Complemento?", + "action.deleteSuccess": "Plugin eliminado correctamente", "action.pluginInfo": "Información del plugin", "action.usedInApps": "Este plugin se está utilizando en las aplicaciones {{num}}.", "allCategories": "Todas las categorías", @@ -114,6 +115,7 @@ "detailPanel.operation.install": "Instalar", "detailPanel.operation.remove": "Eliminar", "detailPanel.operation.update": "Actualizar", + "detailPanel.operation.updateTooltip": "Actualiza para acceder a los modelos más recientes.", "detailPanel.operation.viewDetail": "Ver Detalle", "detailPanel.serviceOk": "Servicio OK", "detailPanel.strategyNum": "{{num}} {{strategy}} INCLUIDO", @@ -231,12 +233,18 @@ "source.local": "Archivo de paquete local", "source.marketplace": "Mercado", "task.clearAll": "Borrar todo", + "task.errorMsg.github": "Este plugin no se pudo instalar automáticamente.\nPor favor, instálalo desde GitHub.", + "task.errorMsg.marketplace": "Este plugin no se pudo instalar automáticamente.\nPor favor, instálalo desde el Marketplace.", + "task.errorMsg.unknown": "Este plugin no se pudo instalar.\nNo se pudo identificar el origen del plugin.", "task.errorPlugins": "Failed to Install Plugins", "task.installError": "Los complementos {{errorLength}} no se pudieron instalar, haga clic para ver", + "task.installFromGithub": "Instalar desde GitHub", + "task.installFromMarketplace": "Instalar desde el Marketplace", "task.installSuccess": "{{successLength}} plugins installed successfully", "task.installed": "Installed", "task.installedError": "Los complementos {{errorLength}} 
no se pudieron instalar", "task.installing": "Instalando plugins.", + "task.installingHint": "Instalando... Esto puede tardar unos minutos.", "task.installingWithError": "Instalando plugins {{installingLength}}, {{successLength}} éxito, {{errorLength}} fallido", "task.installingWithSuccess": "Instalando plugins {{installingLength}}, {{successLength}} éxito.", "task.runningPlugins": "Installing Plugins", diff --git a/web/i18n/es-ES/workflow.json b/web/i18n/es-ES/workflow.json index d91cbd7cb2..393859a36f 100644 --- a/web/i18n/es-ES/workflow.json +++ b/web/i18n/es-ES/workflow.json @@ -305,6 +305,7 @@ "error.operations.updatingWorkflow": "actualizando flujo de trabajo", "error.startNodeRequired": "Por favor, agregue primero un nodo de inicio antes de {{operation}}", "errorMsg.authRequired": "Se requiere autorización", + "errorMsg.configureModel": "Configura un modelo", "errorMsg.fieldRequired": "Se requiere {{field}}", "errorMsg.fields.code": "Código", "errorMsg.fields.model": "Modelo", @@ -314,6 +315,7 @@ "errorMsg.fields.visionVariable": "Variable de visión", "errorMsg.invalidJson": "{{field}} no es un JSON válido", "errorMsg.invalidVariable": "Variable no válida", + "errorMsg.modelPluginNotInstalled": "Variable no válida. 
Configura un modelo para habilitar esta variable.", "errorMsg.noValidTool": "{{field}} no se ha seleccionado ninguna herramienta válida", "errorMsg.rerankModelRequired": "Antes de activar el modelo de reclasificación, confirme que el modelo se ha configurado correctamente en la configuración.", "errorMsg.startNodeRequired": "Por favor, agregue primero un nodo de inicio antes de {{operation}}", @@ -444,6 +446,7 @@ "nodes.common.memory.windowSize": "Tamaño de ventana", "nodes.common.outputVars": "Variables de salida", "nodes.common.pluginNotInstalled": "El complemento no está instalado", + "nodes.common.pluginsNotInstalled": "{{count}} plugins no instalados", "nodes.common.retry.maxRetries": "Número máximo de reintentos", "nodes.common.retry.ms": "Sra.", "nodes.common.retry.retries": "{{num}} Reintentos", @@ -686,9 +689,14 @@ "nodes.knowledgeBase.chunksInput": "Trozo", "nodes.knowledgeBase.chunksInputTip": "La variable de entrada del nodo de la base de conocimientos es Chunks. El tipo de variable es un objeto con un esquema JSON específico que debe ser consistente con la estructura del fragmento seleccionado.", "nodes.knowledgeBase.chunksVariableIsRequired": "La variable Chunks es obligatoria", + "nodes.knowledgeBase.embeddingModelApiKeyUnavailable": "API Key no disponible", + "nodes.knowledgeBase.embeddingModelCreditsExhausted": "Créditos agotados", + "nodes.knowledgeBase.embeddingModelIncompatible": "Incompatible", "nodes.knowledgeBase.embeddingModelIsInvalid": "El modelo de incrustación no es válido", "nodes.knowledgeBase.embeddingModelIsRequired": "Se requiere un modelo de incrustación", + "nodes.knowledgeBase.embeddingModelNotConfigured": "Modelo de embedding no configurado", "nodes.knowledgeBase.indexMethodIsRequired": "Se requiere el método de índice", + "nodes.knowledgeBase.notConfigured": "No configurado", "nodes.knowledgeBase.rerankingModelIsInvalid": "El modelo de reordenación no es válido", "nodes.knowledgeBase.rerankingModelIsRequired": "Se requiere un 
modelo de reordenamiento", "nodes.knowledgeBase.retrievalSettingIsRequired": "Se requiere configuración de recuperación", @@ -872,6 +880,7 @@ "nodes.templateTransform.codeSupportTip": "Solo admite Jinja2", "nodes.templateTransform.inputVars": "Variables de entrada", "nodes.templateTransform.outputVars.output": "Contenido transformado", + "nodes.tool.authorizationRequired": "Autorización requerida", "nodes.tool.authorize": "autorizar", "nodes.tool.inputVars": "Variables de entrada", "nodes.tool.insertPlaceholder1": "Escribe o presiona", @@ -1062,10 +1071,12 @@ "panel.change": "Cambiar", "panel.changeBlock": "Cambiar Nodo", "panel.checklist": "Lista de verificación", + "panel.checklistDescription": "Resuelve los siguientes problemas antes de publicar", "panel.checklistResolved": "Se resolvieron todos los problemas", "panel.checklistTip": "Asegúrate de resolver todos los problemas antes de publicar", "panel.createdBy": "Creado por ", "panel.goTo": "Ir a", + "panel.goToFix": "Ir a corregir", "panel.helpLink": "Ayuda", "panel.maximize": "Maximizar Canvas", "panel.minimize": "Salir de pantalla completa", diff --git a/web/i18n/fa-IR/app-debug.json b/web/i18n/fa-IR/app-debug.json index 5427ebb72a..a612105f35 100644 --- a/web/i18n/fa-IR/app-debug.json +++ b/web/i18n/fa-IR/app-debug.json @@ -235,11 +235,16 @@ "inputs.run": "اجرا", "inputs.title": "اشکال زدایی و پیش نمایش", "inputs.userInputField": "فیلد ورودی کاربر", + "manageModels": "مدیریت مدل‌ها", "modelConfig.modeType.chat": "چت", "modelConfig.modeType.completion": "کامل", "modelConfig.model": "مدل", "modelConfig.setTone": "لحن پاسخ ها را تنظیم کنید", "modelConfig.title": "مدل و پارامترها", + "noModelProviderConfigured": "هیچ ارائه‌دهنده مدلی پیکربندی نشده است", + "noModelProviderConfiguredTip": "برای شروع، یک ارائه‌دهنده مدل نصب یا پیکربندی کنید.", + "noModelSelected": "هیچ مدلی انتخاب نشده است", + "noModelSelectedTip": "برای ادامه، یک مدل را در بالا پیکربندی کنید.", "noResult": "خروجی در اینجا نمایش داده می شود.", 
"notSetAPIKey.description": "کلید ارائه‌دهنده LLM تنظیم نشده است و باید قبل از دیباگ تنظیم شود.", "notSetAPIKey.settingBtn": "به تنظیمات بروید", diff --git a/web/i18n/fa-IR/common.json b/web/i18n/fa-IR/common.json index ec535138a4..d2b1e8158c 100644 --- a/web/i18n/fa-IR/common.json +++ b/web/i18n/fa-IR/common.json @@ -340,11 +340,25 @@ "modelProvider.auth.unAuthorized": "بدون مجوز", "modelProvider.buyQuota": "خرید سهمیه", "modelProvider.callTimes": "تعداد فراخوانی", + "modelProvider.card.aiCreditsInUse": "اعتبار هوش مصنوعی در حال استفاده", + "modelProvider.card.aiCreditsOption": "اعتبار هوش مصنوعی", + "modelProvider.card.apiKeyOption": "کلید API", + "modelProvider.card.apiKeyRequired": "کلید API الزامی است", + "modelProvider.card.apiKeyUnavailableFallback": "کلید API در دسترس نیست، در حال استفاده از اعتبار هوش مصنوعی", + "modelProvider.card.apiKeyUnavailableFallbackDescription": "پیکربندی کلید API خود را بررسی کنید تا بازگردید", "modelProvider.card.buyQuota": "خرید سهمیه", "modelProvider.card.callTimes": "تعداد فراخوانی", + "modelProvider.card.creditsExhaustedDescription": "لطفاً طرح خود را ارتقا دهید یا یک کلید API پیکربندی کنید", + "modelProvider.card.creditsExhaustedFallback": "اعتبار هوش مصنوعی تمام شده، در حال استفاده از کلید API", + "modelProvider.card.creditsExhaustedFallbackDescription": "طرح خود را ارتقا دهید تا اولویت اعتبار هوش مصنوعی از سر گرفته شود.", + "modelProvider.card.creditsExhaustedMessage": "اعتبار هوش مصنوعی تمام شده است", "modelProvider.card.modelAPI": "مدل‌های {{modelName}} از کلید API استفاده می‌کنند.", "modelProvider.card.modelNotSupported": "مدل‌های {{modelName}} نصب نشده‌اند.", "modelProvider.card.modelSupported": "مدل‌های {{modelName}} از این سهمیه استفاده می‌کنند.", + "modelProvider.card.noApiKeysDescription": "یک کلید API اضافه کنید تا از اعتبارنامه‌های مدل خود استفاده کنید.", + "modelProvider.card.noApiKeysFallback": "بدون کلید API، در حال استفاده از اعتبار هوش مصنوعی", + "modelProvider.card.noApiKeysTitle": "هنوز کلید API پیکربندی 
نشده است", + "modelProvider.card.noAvailableUsage": "هیچ مصرفی در دسترس نیست", "modelProvider.card.onTrial": "در حال آزمایش", "modelProvider.card.paid": "پرداخت شده", "modelProvider.card.priorityUse": "استفاده با اولویت", @@ -353,6 +367,11 @@ "modelProvider.card.removeKey": "حذف کلید API", "modelProvider.card.tip": "اعتبار پیام از مدل‌های {{modelNames}} پشتیبانی می‌کند. اولویت به سهمیه پرداخت شده داده می‌شود. سهمیه رایگان پس از اتمام سهمیه پرداخت شده استفاده خواهد شد.", "modelProvider.card.tokens": "توکن‌ها", + "modelProvider.card.unavailable": "در دسترس نیست", + "modelProvider.card.upgradePlan": "طرح خود را ارتقا دهید", + "modelProvider.card.usageLabel": "مصرف", + "modelProvider.card.usagePriority": "اولویت مصرف", + "modelProvider.card.usagePriorityTip": "تعیین کنید که هنگام اجرای مدل‌ها کدام منبع اول استفاده شود.", "modelProvider.collapse": "جمع کردن", "modelProvider.config": "پیکربندی", "modelProvider.configLoadBalancing": "پیکربندی تعادل بار", @@ -387,9 +406,11 @@ "modelProvider.model": "مدل", "modelProvider.modelAndParameters": "مدل و پارامترها", "modelProvider.modelHasBeenDeprecated": "این مدل منسوخ شده است", + "modelProvider.modelSettings": "تنظیمات مدل", "modelProvider.models": "مدل‌ها", "modelProvider.modelsNum": "{{num}} مدل", "modelProvider.noModelFound": "هیچ مدلی برای {{model}} یافت نشد", + "modelProvider.noneConfigured": "یک مدل سیستم پیش‌فرض برای اجرای برنامه‌ها پیکربندی کنید", "modelProvider.notConfigured": "مدل سیستم هنوز به طور کامل پیکربندی نشده است و برخی از عملکردها ممکن است در دسترس نباشند.", "modelProvider.parameters": "پارامترها", "modelProvider.parametersInvalidRemoved": "برخی پارامترها نامعتبر هستند و حذف شده‌اند", @@ -403,8 +424,25 @@ "modelProvider.resetDate": "بازنشانی در {{date}}", "modelProvider.searchModel": "جستجوی مدل", "modelProvider.selectModel": "مدل خود را انتخاب کنید", + "modelProvider.selector.aiCredits": "اعتبار هوش مصنوعی", + "modelProvider.selector.apiKeyUnavailable": "کلید API در دسترس نیست", + 
"modelProvider.selector.apiKeyUnavailableTip": "کلید API حذف شده است. لطفاً یک کلید API جدید پیکربندی کنید.", + "modelProvider.selector.configure": "پیکربندی", + "modelProvider.selector.configureRequired": "پیکربندی الزامی است", + "modelProvider.selector.creditsExhausted": "اعتبار تمام شده", + "modelProvider.selector.creditsExhaustedTip": "اعتبار هوش مصنوعی شما تمام شده است. لطفاً طرح خود را ارتقا دهید یا یک کلید API اضافه کنید.", + "modelProvider.selector.disabled": "غیرفعال", + "modelProvider.selector.discoverMoreInMarketplace": "اطلاعات بیشتر در Marketplace", "modelProvider.selector.emptySetting": "لطفاً به تنظیمات بروید تا پیکربندی کنید", "modelProvider.selector.emptyTip": "هیچ مدل موجودی وجود ندارد", + "modelProvider.selector.fromMarketplace": "از Marketplace", + "modelProvider.selector.incompatible": "ناسازگار", + "modelProvider.selector.incompatibleTip": "این مدل در نسخه فعلی موجود نیست. لطفاً مدل دیگری انتخاب کنید.", + "modelProvider.selector.install": "نصب", + "modelProvider.selector.modelProviderSettings": "تنظیمات ارائه‌دهنده مدل", + "modelProvider.selector.noProviderConfigured": "هیچ ارائه‌دهنده مدلی پیکربندی نشده است", + "modelProvider.selector.noProviderConfiguredDesc": "در Marketplace جستجو کنید تا یکی نصب کنید، یا ارائه‌دهندگان را در تنظیمات پیکربندی کنید.", + "modelProvider.selector.onlyCompatibleModelsShown": "فقط مدل‌های سازگار نمایش داده می‌شوند", "modelProvider.selector.rerankTip": "لطفاً مدل رتبه‌بندی مجدد را تنظیم کنید", "modelProvider.selector.tip": "این مدل حذف شده است. 
لطفاً یک مدل اضافه کنید یا مدل دیگری را انتخاب کنید.", "modelProvider.setupModelFirst": "لطفاً ابتدا مدل خود را تنظیم کنید", diff --git a/web/i18n/fa-IR/plugin.json b/web/i18n/fa-IR/plugin.json index a1bbc8ee63..bee2111b44 100644 --- a/web/i18n/fa-IR/plugin.json +++ b/web/i18n/fa-IR/plugin.json @@ -3,6 +3,7 @@ "action.delete": "حذف افزونه", "action.deleteContentLeft": "آیا می خواهید", "action.deleteContentRight": "افزونه?", + "action.deleteSuccess": "افزونه با موفقیت حذف شد", "action.pluginInfo": "اطلاعات پلاگین", "action.usedInApps": "این افزونه در برنامه های {{num}} استفاده می شود.", "allCategories": "همه دسته بندی ها", @@ -114,6 +115,7 @@ "detailPanel.operation.install": "نصب", "detailPanel.operation.remove": "حذف", "detailPanel.operation.update": "روز رسانی", + "detailPanel.operation.updateTooltip": "به‌روزرسانی کنید تا به جدیدترین مدل‌ها دسترسی پیدا کنید.", "detailPanel.operation.viewDetail": "نمایش جزئیات", "detailPanel.serviceOk": "خدمات خوب", "detailPanel.strategyNum": "{{num}} {{strategy}} شامل", @@ -231,12 +233,18 @@ "source.local": "فایل بسته محلی", "source.marketplace": "بازار", "task.clearAll": "پاک کردن همه", + "task.errorMsg.github": "این افزونه به‌صورت خودکار نصب نشد.\nلطفاً آن را از GitHub نصب کنید.", + "task.errorMsg.marketplace": "این افزونه به‌صورت خودکار نصب نشد.\nلطفاً آن را از Marketplace نصب کنید.", + "task.errorMsg.unknown": "این افزونه نصب نشد.\nمنبع افزونه قابل شناسایی نبود.", "task.errorPlugins": "Failed to Install Plugins", "task.installError": "پلاگین های {{errorLength}} نصب نشدند، برای مشاهده کلیک کنید", + "task.installFromGithub": "نصب از GitHub", + "task.installFromMarketplace": "نصب از Marketplace", "task.installSuccess": "{{successLength}} plugins installed successfully", "task.installed": "Installed", "task.installedError": "افزونه های {{errorLength}} نصب نشدند", "task.installing": "در حال نصب پلاگین‌ها.", + "task.installingHint": "در حال نصب... 
این ممکن است چند دقیقه طول بکشد.", "task.installingWithError": "نصب پلاگین های {{installingLength}}، {{successLength}} با موفقیت مواجه شد، {{errorLength}} ناموفق بود", "task.installingWithSuccess": "نصب پلاگین های {{installingLength}}، {{successLength}} موفقیت آمیز است.", "task.runningPlugins": "Installing Plugins", diff --git a/web/i18n/fa-IR/workflow.json b/web/i18n/fa-IR/workflow.json index 676179e6d4..7a8fca11f1 100644 --- a/web/i18n/fa-IR/workflow.json +++ b/web/i18n/fa-IR/workflow.json @@ -305,6 +305,7 @@ "error.operations.updatingWorkflow": "به‌روزرسانی گردش کار", "error.startNodeRequired": "لطفاً قبل از {{operation}} ابتدا یک گره شروع اضافه کنید", "errorMsg.authRequired": "احراز هویت الزامی است", + "errorMsg.configureModel": "یک مدل پیکربندی کنید", "errorMsg.fieldRequired": "{{field}} الزامی است", "errorMsg.fields.code": "کد", "errorMsg.fields.model": "مدل", @@ -314,6 +315,7 @@ "errorMsg.fields.visionVariable": "متغیر بینایی", "errorMsg.invalidJson": "{{field}} یک JSON معتبر نیست", "errorMsg.invalidVariable": "متغیر نامعتبر", + "errorMsg.modelPluginNotInstalled": "متغیر نامعتبر. برای فعال‌سازی این متغیر، یک مدل پیکربندی کنید.", "errorMsg.noValidTool": "{{field}} هیچ ابزار معتبری انتخاب نشده است", "errorMsg.rerankModelRequired": "قبل از فعال‌سازی Rerank Model، لطفاً مطمئن شوید که مدل در تنظیمات با موفقیت پیکربندی شده است.", "errorMsg.startNodeRequired": "لطفاً قبل از {{operation}} ابتدا یک گره شروع اضافه کنید", @@ -444,6 +446,7 @@ "nodes.common.memory.windowSize": "اندازه پنجره", "nodes.common.outputVars": "متغیرهای خروجی", "nodes.common.pluginNotInstalled": "افزونه نصب نشده است", + "nodes.common.pluginsNotInstalled": "{{count}} افزونه نصب نشده است", "nodes.common.retry.maxRetries": "حداکثر تلاش مجدد", "nodes.common.retry.ms": "ms", "nodes.common.retry.retries": "{{num}} تلاش مجدد", @@ -686,9 +689,14 @@ "nodes.knowledgeBase.chunksInput": "چانک‌ها", "nodes.knowledgeBase.chunksInputTip": "متغیر ورودی گره پایگاه دانش چانک‌ها است. 
نوع متغیر یک شیء با طرح JSON خاص است که باید با ساختار چانک انتخاب‌شده سازگار باشد.", "nodes.knowledgeBase.chunksVariableIsRequired": "متغیر چانک‌ها الزامی است", + "nodes.knowledgeBase.embeddingModelApiKeyUnavailable": "کلید API در دسترس نیست", + "nodes.knowledgeBase.embeddingModelCreditsExhausted": "اعتبار تمام شده", + "nodes.knowledgeBase.embeddingModelIncompatible": "ناسازگار", "nodes.knowledgeBase.embeddingModelIsInvalid": "مدل Embedding نامعتبر است", "nodes.knowledgeBase.embeddingModelIsRequired": "مدل Embedding الزامی است", + "nodes.knowledgeBase.embeddingModelNotConfigured": "مدل Embedding پیکربندی نشده است", "nodes.knowledgeBase.indexMethodIsRequired": "روش ایندکس‌گذاری الزامی است", + "nodes.knowledgeBase.notConfigured": "پیکربندی نشده", "nodes.knowledgeBase.rerankingModelIsInvalid": "مدل بازرتبه‌بندی نامعتبر است", "nodes.knowledgeBase.rerankingModelIsRequired": "مدل بازرتبه‌بندی الزامی است", "nodes.knowledgeBase.retrievalSettingIsRequired": "تنظیمات بازیابی الزامی است", @@ -872,6 +880,7 @@ "nodes.templateTransform.codeSupportTip": "فقط از Jinja2 پشتیبانی می‌شود", "nodes.templateTransform.inputVars": "متغیرهای ورودی", "nodes.templateTransform.outputVars.output": "محتوای تبدیل‌شده", + "nodes.tool.authorizationRequired": "احراز هویت الزامی است", "nodes.tool.authorize": "مجوزدهی", "nodes.tool.inputVars": "متغیرهای ورودی", "nodes.tool.insertPlaceholder1": "تایپ کنید یا فشار دهید", @@ -1062,10 +1071,12 @@ "panel.change": "تغییر", "panel.changeBlock": "تغییر گره", "panel.checklist": "چک‌لیست", + "panel.checklistDescription": "مشکلات زیر را پیش از انتشار برطرف کنید", "panel.checklistResolved": "تمام مشکلات برطرف شده‌اند", "panel.checklistTip": "قبل از انتشار، مطمئن شوید که تمام مشکلات برطرف شده‌اند", "panel.createdBy": "ساخته شده توسط", "panel.goTo": "برو به", + "panel.goToFix": "برو به رفع", "panel.helpLink": "راهنما", "panel.maximize": "تمام‌صفحه", "panel.minimize": "خروج از تمام‌صفحه", diff --git a/web/i18n/fr-FR/app-debug.json b/web/i18n/fr-FR/app-debug.json 
index 88f75f5136..711959c323 100644 --- a/web/i18n/fr-FR/app-debug.json +++ b/web/i18n/fr-FR/app-debug.json @@ -235,11 +235,16 @@ "inputs.run": "EXÉCUTER", "inputs.title": "Déboguer et Aperçu", "inputs.userInputField": "Champ de saisie utilisateur", + "manageModels": "Gérer les modèles", "modelConfig.modeType.chat": "Discussion", "modelConfig.modeType.completion": "Complet", "modelConfig.model": "Modèle", "modelConfig.setTone": "Définir le ton des réponses", "modelConfig.title": "Modèle et Paramètres", + "noModelProviderConfigured": "Aucun fournisseur de modèle configuré", + "noModelProviderConfiguredTip": "Installez ou configurez un fournisseur de modèle pour commencer.", + "noModelSelected": "Aucun modèle sélectionné", + "noModelSelectedTip": "Configurez un modèle ci-dessus pour continuer.", "noResult": "La sortie sera affichée ici.", "notSetAPIKey.description": "La clé du fournisseur LLM n'a pas été définie, et elle doit être définie avant le débogage.", "notSetAPIKey.settingBtn": "Aller aux paramètres", diff --git a/web/i18n/fr-FR/common.json b/web/i18n/fr-FR/common.json index 42b40b2a2a..8710c04c44 100644 --- a/web/i18n/fr-FR/common.json +++ b/web/i18n/fr-FR/common.json @@ -340,11 +340,25 @@ "modelProvider.auth.unAuthorized": "Non autorisé", "modelProvider.buyQuota": "Acheter Quota", "modelProvider.callTimes": "Temps d'appel", + "modelProvider.card.aiCreditsInUse": "AI credits en cours d'utilisation", + "modelProvider.card.aiCreditsOption": "AI credits", + "modelProvider.card.apiKeyOption": "API Key", + "modelProvider.card.apiKeyRequired": "API Key requise", + "modelProvider.card.apiKeyUnavailableFallback": "API Key indisponible, utilisation des AI credits", + "modelProvider.card.apiKeyUnavailableFallbackDescription": "Vérifiez la configuration de votre API Key pour revenir en arrière", "modelProvider.card.buyQuota": "Acheter Quota", "modelProvider.card.callTimes": "Temps d'appel", + "modelProvider.card.creditsExhaustedDescription": "Veuillez mettre à niveau 
votre forfait ou configurer une API Key", + "modelProvider.card.creditsExhaustedFallback": "AI credits épuisés, utilisation de l'API Key", + "modelProvider.card.creditsExhaustedFallbackDescription": "Mettez à niveau votre forfait pour rétablir la priorité des AI credits.", + "modelProvider.card.creditsExhaustedMessage": "Les AI credits ont été épuisés", "modelProvider.card.modelAPI": "Les modèles {{modelName}} utilisent la clé API.", "modelProvider.card.modelNotSupported": "Les modèles {{modelName}} ne sont pas installés.", "modelProvider.card.modelSupported": "Les modèles {{modelName}} utilisent ce quota.", + "modelProvider.card.noApiKeysDescription": "Ajoutez une API Key pour utiliser vos propres identifiants de modèle.", + "modelProvider.card.noApiKeysFallback": "Aucune API Key, utilisation des AI credits", + "modelProvider.card.noApiKeysTitle": "Aucune API Key configurée", + "modelProvider.card.noAvailableUsage": "Aucune utilisation disponible", "modelProvider.card.onTrial": "En Essai", "modelProvider.card.paid": "Payé", "modelProvider.card.priorityUse": "Utilisation prioritaire", @@ -353,6 +367,11 @@ "modelProvider.card.removeKey": "Supprimer la clé API", "modelProvider.card.tip": "Les crédits de messages prennent en charge les modèles de {{modelNames}}. La priorité sera donnée au quota payant. 
Le quota gratuit sera utilisé après épuisement du quota payant.", "modelProvider.card.tokens": "Jetons", + "modelProvider.card.unavailable": "Indisponible", + "modelProvider.card.upgradePlan": "mettre à niveau votre forfait", + "modelProvider.card.usageLabel": "Utilisation", + "modelProvider.card.usagePriority": "Priorité d'utilisation", + "modelProvider.card.usagePriorityTip": "Définissez la ressource à utiliser en priorité lors de l'exécution des modèles.", "modelProvider.collapse": "Effondrer", "modelProvider.config": "Configuration", "modelProvider.configLoadBalancing": "Équilibrage de charge de configuration", @@ -387,9 +406,11 @@ "modelProvider.model": "Modèle", "modelProvider.modelAndParameters": "Modèle et Paramètres", "modelProvider.modelHasBeenDeprecated": "Ce modèle est obsolète", + "modelProvider.modelSettings": "Paramètres du modèle", "modelProvider.models": "Modèles", "modelProvider.modelsNum": "{{num}} Modèles", "modelProvider.noModelFound": "Aucun modèle trouvé pour {{model}}", + "modelProvider.noneConfigured": "Configurez un modèle système par défaut pour exécuter les applications", "modelProvider.notConfigured": "Le modèle du système n'a pas encore été entièrement configuré, et certaines fonctions peuvent être indisponibles.", "modelProvider.parameters": "PARAMÈTRES", "modelProvider.parametersInvalidRemoved": "Certains paramètres sont invalides et ont été supprimés.", @@ -403,8 +424,25 @@ "modelProvider.resetDate": "Réinitialiser le {{date}}", "modelProvider.searchModel": "Modèle de recherche", "modelProvider.selectModel": "Sélectionnez votre modèle", + "modelProvider.selector.aiCredits": "AI credits", + "modelProvider.selector.apiKeyUnavailable": "API Key indisponible", + "modelProvider.selector.apiKeyUnavailableTip": "L'API Key a été supprimée. 
Veuillez configurer une nouvelle API Key.", + "modelProvider.selector.configure": "Configurer", + "modelProvider.selector.configureRequired": "Configuration requise", + "modelProvider.selector.creditsExhausted": "Crédits épuisés", + "modelProvider.selector.creditsExhaustedTip": "Vos AI credits ont été épuisés. Veuillez mettre à niveau votre forfait ou ajouter une API Key.", + "modelProvider.selector.disabled": "Désactivé", + "modelProvider.selector.discoverMoreInMarketplace": "Découvrir plus sur le Marketplace", "modelProvider.selector.emptySetting": "Veuillez aller dans les paramètres pour configurer", "modelProvider.selector.emptyTip": "Aucun modèle disponible", + "modelProvider.selector.fromMarketplace": "Depuis le Marketplace", + "modelProvider.selector.incompatible": "Incompatible", + "modelProvider.selector.incompatibleTip": "Ce modèle n'est pas disponible dans la version actuelle. Veuillez sélectionner un autre modèle disponible.", + "modelProvider.selector.install": "Installer", + "modelProvider.selector.modelProviderSettings": "Paramètres du fournisseur de modèle", + "modelProvider.selector.noProviderConfigured": "Aucun fournisseur de modèle configuré", + "modelProvider.selector.noProviderConfiguredDesc": "Parcourez le Marketplace pour en installer un, ou configurez les fournisseurs dans les paramètres.", + "modelProvider.selector.onlyCompatibleModelsShown": "Seuls les modèles compatibles sont affichés", "modelProvider.selector.rerankTip": "Veuillez configurer le modèle Rerank", "modelProvider.selector.tip": "Ce modèle a été supprimé. 
Veuillez ajouter un modèle ou sélectionner un autre modèle.", "modelProvider.setupModelFirst": "Veuillez d'abord configurer votre modèle", diff --git a/web/i18n/fr-FR/plugin.json b/web/i18n/fr-FR/plugin.json index 79f43acb8e..cbea19646d 100644 --- a/web/i18n/fr-FR/plugin.json +++ b/web/i18n/fr-FR/plugin.json @@ -3,6 +3,7 @@ "action.delete": "Supprimer le plugin", "action.deleteContentLeft": "Souhaitez-vous supprimer", "action.deleteContentRight": "Plug-in ?", + "action.deleteSuccess": "Plugin supprimé avec succès", "action.pluginInfo": "Informations sur le plugin", "action.usedInApps": "Ce plugin est utilisé dans les applications {{num}}.", "allCategories": "Toutes les catégories", @@ -114,6 +115,7 @@ "detailPanel.operation.install": "Installer", "detailPanel.operation.remove": "Enlever", "detailPanel.operation.update": "Mettre à jour", + "detailPanel.operation.updateTooltip": "Mettez à jour pour accéder aux derniers modèles.", "detailPanel.operation.viewDetail": "Voir les détails", "detailPanel.serviceOk": "Service OK", "detailPanel.strategyNum": "{{num}} {{strategy}} INCLUS", @@ -231,12 +233,18 @@ "source.local": "Fichier de package local", "source.marketplace": "Marché", "task.clearAll": "Effacer tout", + "task.errorMsg.github": "Ce plugin n'a pas pu être installé automatiquement.\nVeuillez l'installer depuis GitHub.", + "task.errorMsg.marketplace": "Ce plugin n'a pas pu être installé automatiquement.\nVeuillez l'installer depuis le Marketplace.", + "task.errorMsg.unknown": "Ce plugin n'a pas pu être installé.\nLa source du plugin n'a pas pu être identifiée.", "task.errorPlugins": "Failed to Install Plugins", "task.installError": "{{errorLength}} les plugins n’ont pas pu être installés, cliquez pour voir", + "task.installFromGithub": "Installer depuis GitHub", + "task.installFromMarketplace": "Installer depuis le Marketplace", "task.installSuccess": "{{successLength}} plugins installed successfully", "task.installed": "Installed", "task.installedError": 
"{{errorLength}} les plugins n’ont pas pu être installés", "task.installing": "Installation des plugins.", + "task.installingHint": "Installation en cours... Cela peut prendre quelques minutes.", "task.installingWithError": "Installation des plugins {{installingLength}}, succès de {{successLength}}, échec de {{errorLength}}", "task.installingWithSuccess": "Installation des plugins {{installingLength}}, succès de {{successLength}}.", "task.runningPlugins": "Installing Plugins", diff --git a/web/i18n/fr-FR/workflow.json b/web/i18n/fr-FR/workflow.json index b5f67f74b3..09d140445e 100644 --- a/web/i18n/fr-FR/workflow.json +++ b/web/i18n/fr-FR/workflow.json @@ -305,6 +305,7 @@ "error.operations.updatingWorkflow": "mise à jour du flux de travail", "error.startNodeRequired": "Veuillez d'abord ajouter un nœud de départ avant {{operation}}", "errorMsg.authRequired": "Autorisation requise", + "errorMsg.configureModel": "Configurez un modèle", "errorMsg.fieldRequired": "{{field}} est requis", "errorMsg.fields.code": "Code", "errorMsg.fields.model": "Modèle", @@ -314,6 +315,7 @@ "errorMsg.fields.visionVariable": "Vision Variable", "errorMsg.invalidJson": "{{field}} est un JSON invalide", "errorMsg.invalidVariable": "Variable invalide", + "errorMsg.modelPluginNotInstalled": "Variable invalide. 
Configurez un modèle pour activer cette variable.", "errorMsg.noValidTool": "{{field}} aucun outil valide sélectionné", "errorMsg.rerankModelRequired": "Avant d’activer le modèle de reclassement, veuillez confirmer que le modèle a été correctement configuré dans les paramètres.", "errorMsg.startNodeRequired": "Veuillez d'abord ajouter un nœud de départ avant {{operation}}", @@ -444,6 +446,7 @@ "nodes.common.memory.windowSize": "Taille de la fenêtre", "nodes.common.outputVars": "Variables de sortie", "nodes.common.pluginNotInstalled": "Le plugin n'est pas installé", + "nodes.common.pluginsNotInstalled": "{{count}} plugins non installés", "nodes.common.retry.maxRetries": "Nombre maximal de tentatives", "nodes.common.retry.ms": "ms", "nodes.common.retry.retries": "{{num}} Tentatives", @@ -686,9 +689,14 @@ "nodes.knowledgeBase.chunksInput": "Morceaux", "nodes.knowledgeBase.chunksInputTip": "La variable d'entrée du nœud de la base de connaissances est Chunks. Le type de variable est un objet avec un schéma JSON spécifique qui doit être cohérent avec la structure de morceau sélectionnée.", "nodes.knowledgeBase.chunksVariableIsRequired": "La variable Chunks est requise", + "nodes.knowledgeBase.embeddingModelApiKeyUnavailable": "API Key indisponible", + "nodes.knowledgeBase.embeddingModelCreditsExhausted": "Crédits épuisés", + "nodes.knowledgeBase.embeddingModelIncompatible": "Incompatible", "nodes.knowledgeBase.embeddingModelIsInvalid": "Le modèle d'intégration est invalide", - "nodes.knowledgeBase.embeddingModelIsRequired": "Un modèle d'intégration est requis", + "nodes.knowledgeBase.embeddingModelIsRequired": "Un modèle d’intégration est requis", + "nodes.knowledgeBase.embeddingModelNotConfigured": "Modèle d’embedding non configuré", "nodes.knowledgeBase.indexMethodIsRequired": "La méthode d’indexation est requise", + "nodes.knowledgeBase.notConfigured": "Non configuré", "nodes.knowledgeBase.rerankingModelIsInvalid": "Le modèle de rerank est invalide", 
"nodes.knowledgeBase.rerankingModelIsRequired": "Un modèle de rerankage est requis", "nodes.knowledgeBase.retrievalSettingIsRequired": "Le paramètre de récupération est requis", @@ -872,6 +880,7 @@ "nodes.templateTransform.codeSupportTip": "Prend en charge uniquement Jinja2", "nodes.templateTransform.inputVars": "Variables de saisie", "nodes.templateTransform.outputVars.output": "Contenu transformé", + "nodes.tool.authorizationRequired": "Autorisation requise", "nodes.tool.authorize": "Autoriser", "nodes.tool.inputVars": "Variables de saisie", "nodes.tool.insertPlaceholder1": "Tapez ou appuyez", @@ -1062,10 +1071,12 @@ "panel.change": "Modifier", "panel.changeBlock": "Changer de nœud", "panel.checklist": "Liste de contrôle", + "panel.checklistDescription": "Résolvez les problèmes suivants avant de publier", "panel.checklistResolved": "Tous les problèmes ont été résolus", "panel.checklistTip": "Assurez-vous que tous les problèmes sont résolus avant de publier", "panel.createdBy": "Créé par", "panel.goTo": "Aller à", + "panel.goToFix": "Aller corriger", "panel.helpLink": "Aide", "panel.maximize": "Maximiser le Canvas", "panel.minimize": "Sortir du mode plein écran", diff --git a/web/i18n/hi-IN/app-debug.json b/web/i18n/hi-IN/app-debug.json index 97e733e167..f9d438f0b1 100644 --- a/web/i18n/hi-IN/app-debug.json +++ b/web/i18n/hi-IN/app-debug.json @@ -235,11 +235,16 @@ "inputs.run": "चालू करें", "inputs.title": "डिबग और पूर्वावलोकन", "inputs.userInputField": "उपयोगकर्ता इनपुट फ़ील्ड", + "manageModels": "मॉडल प्रबंधित करें", "modelConfig.modeType.chat": "चैट", "modelConfig.modeType.completion": "पूर्ण", "modelConfig.model": "मॉडल", "modelConfig.setTone": "प्रतिक्रियाओं की टोन सेट करें", "modelConfig.title": "मॉडल और पैरामीटर", + "noModelProviderConfigured": "कोई मॉडल प्रदाता कॉन्फ़िगर नहीं है", + "noModelProviderConfiguredTip": "शुरू करने के लिए एक मॉडल प्रदाता इंस्टॉल या कॉन्फ़िगर करें।", + "noModelSelected": "कोई मॉडल चयनित नहीं है", + "noModelSelectedTip": "जारी रखने 
के लिए ऊपर एक मॉडल कॉन्फ़िगर करें।", "noResult": "प्रदर्शन यहाँ होगा।", "notSetAPIKey.description": "एलएलएम प्रदाता कुंजी सेट नहीं की गई है, और डीबग करने से पहले इसे सेट करने की आवश्यकता है।", "notSetAPIKey.settingBtn": "सेटिंग्स पर जाएं", diff --git a/web/i18n/hi-IN/common.json b/web/i18n/hi-IN/common.json index 75ac76add8..e61c96ca45 100644 --- a/web/i18n/hi-IN/common.json +++ b/web/i18n/hi-IN/common.json @@ -340,11 +340,25 @@ "modelProvider.auth.unAuthorized": "अअनधिकारित", "modelProvider.buyQuota": "कोटा खरीदें", "modelProvider.callTimes": "कॉल समय", + "modelProvider.card.aiCreditsInUse": "AI credits उपयोग में हैं", + "modelProvider.card.aiCreditsOption": "AI credits", + "modelProvider.card.apiKeyOption": "API Key", + "modelProvider.card.apiKeyRequired": "API key आवश्यक है", + "modelProvider.card.apiKeyUnavailableFallback": "API Key अनुपलब्ध, AI credits का उपयोग हो रहा है", + "modelProvider.card.apiKeyUnavailableFallbackDescription": "वापस स्विच करने के लिए अपना API key कॉन्फ़िगरेशन जाँचें", "modelProvider.card.buyQuota": "कोटा खरीदें", "modelProvider.card.callTimes": "कॉल समय", + "modelProvider.card.creditsExhaustedDescription": "कृपया अपना प्लान अपग्रेड करें या API key कॉन्फ़िगर करें", + "modelProvider.card.creditsExhaustedFallback": "AI credits समाप्त, API Key का उपयोग हो रहा है", + "modelProvider.card.creditsExhaustedFallbackDescription": "AI credit प्राथमिकता फिर से शुरू करने के लिए अपना प्लान अपग्रेड करें।", + "modelProvider.card.creditsExhaustedMessage": "AI credits समाप्त हो गए हैं", "modelProvider.card.modelAPI": "{{modelName}} मॉडल API कुंजी का उपयोग कर रहे हैं।", "modelProvider.card.modelNotSupported": "{{modelName}} मॉडल इंस्टॉल नहीं हैं।", "modelProvider.card.modelSupported": "{{modelName}} मॉडल इस कोटा का उपयोग कर रहे हैं।", + "modelProvider.card.noApiKeysDescription": "अपने स्वयं के मॉडल क्रेडेंशियल का उपयोग शुरू करने के लिए API key जोड़ें।", + "modelProvider.card.noApiKeysFallback": "कोई API key नहीं, AI credits का उपयोग हो रहा है", + 
"modelProvider.card.noApiKeysTitle": "अभी तक कोई API key कॉन्फ़िगर नहीं है", + "modelProvider.card.noAvailableUsage": "कोई उपलब्ध उपयोग नहीं", "modelProvider.card.onTrial": "परीक्षण पर", "modelProvider.card.paid": "भुगतान किया हुआ", "modelProvider.card.priorityUse": "प्राथमिकता उपयोग", @@ -353,6 +367,11 @@ "modelProvider.card.removeKey": "API कुंजी निकालें", "modelProvider.card.tip": "संदेश क्रेडिट {{modelNames}} के मॉडल का समर्थन करते हैं। भुगतान किए गए कोटा को प्राथमिकता दी जाएगी। भुगतान किए गए कोटा के समाप्त होने के बाद मुफ्त कोटा का उपयोग किया जाएगा।", "modelProvider.card.tokens": "टोकन", + "modelProvider.card.unavailable": "अनुपलब्ध", + "modelProvider.card.upgradePlan": "अपना प्लान अपग्रेड करें", + "modelProvider.card.usageLabel": "उपयोग", + "modelProvider.card.usagePriority": "उपयोग प्राथमिकता", + "modelProvider.card.usagePriorityTip": "मॉडल चलाते समय पहले किस संसाधन का उपयोग करना है, यह सेट करें।", "modelProvider.collapse": "संक्षिप्त करें", "modelProvider.config": "कॉन्फ़िग", "modelProvider.configLoadBalancing": "लोड बैलेंसिंग कॉन्फ़िग करें", @@ -387,9 +406,11 @@ "modelProvider.model": "मॉडल", "modelProvider.modelAndParameters": "मॉडल और पैरामीटर", "modelProvider.modelHasBeenDeprecated": "यह मॉडल अप्रचलित हो गया है", + "modelProvider.modelSettings": "मॉडल सेटिंग्स", "modelProvider.models": "मॉडल्स", "modelProvider.modelsNum": "{{num}} मॉडल्स", "modelProvider.noModelFound": "{{model}} के लिए कोई मॉडल नहीं मिला", + "modelProvider.noneConfigured": "एप्लिकेशन चलाने के लिए एक डिफ़ॉल्ट सिस्टम मॉडल कॉन्फ़िगर करें", "modelProvider.notConfigured": "सिस्टम मॉडल को अभी पूरी तरह से कॉन्फ़िगर नहीं किया गया है, और कुछ कार्य उपलब्ध नहीं हो सकते हैं।", "modelProvider.parameters": "पैरामीटर", "modelProvider.parametersInvalidRemoved": "कुछ पैरामीटर अमान्य हैं और हटा दिए गए हैं", @@ -403,8 +424,25 @@ "modelProvider.resetDate": "{{date}} को रीसेट करें", "modelProvider.searchModel": "खोज मॉडल", "modelProvider.selectModel": "अपने मॉडल का चयन करें", + 
"modelProvider.selector.aiCredits": "AI credits", + "modelProvider.selector.apiKeyUnavailable": "API Key अनुपलब्ध", + "modelProvider.selector.apiKeyUnavailableTip": "API key हटा दी गई है। कृपया एक नई API key कॉन्फ़िगर करें।", + "modelProvider.selector.configure": "कॉन्फ़िगर करें", + "modelProvider.selector.configureRequired": "कॉन्फ़िगरेशन आवश्यक", + "modelProvider.selector.creditsExhausted": "क्रेडिट समाप्त", + "modelProvider.selector.creditsExhaustedTip": "आपके AI credits समाप्त हो गए हैं। कृपया अपना प्लान अपग्रेड करें या API key जोड़ें।", + "modelProvider.selector.disabled": "अक्षम", + "modelProvider.selector.discoverMoreInMarketplace": "Marketplace में और खोजें", "modelProvider.selector.emptySetting": "कॉन्फ़िगर करने के लिए कृपया सेटिंग्स पर जाएं", "modelProvider.selector.emptyTip": "कोई उपलब्ध मॉडल नहीं", + "modelProvider.selector.fromMarketplace": "Marketplace से", + "modelProvider.selector.incompatible": "असंगत", + "modelProvider.selector.incompatibleTip": "यह मॉडल वर्तमान संस्करण में उपलब्ध नहीं है। कृपया कोई अन्य उपलब्ध मॉडल चुनें।", + "modelProvider.selector.install": "इंस्टॉल करें", + "modelProvider.selector.modelProviderSettings": "मॉडल प्रदाता सेटिंग्स", + "modelProvider.selector.noProviderConfigured": "कोई मॉडल प्रदाता कॉन्फ़िगर नहीं है", + "modelProvider.selector.noProviderConfiguredDesc": "इंस्टॉल करने के लिए Marketplace ब्राउज़ करें, या सेटिंग्स में प्रदाता कॉन्फ़िगर करें।", + "modelProvider.selector.onlyCompatibleModelsShown": "केवल संगत मॉडल दिखाए गए हैं", "modelProvider.selector.rerankTip": "कृपया रीरैंक मॉडल सेट करें", "modelProvider.selector.tip": "इस मॉडल को हटा दिया गया है। कृपया एक मॉडल जोड़ें या किसी अन्य मॉडल का चयन करें।", "modelProvider.setupModelFirst": "कृपया पहले अपना मॉडल सेट करें", diff --git a/web/i18n/hi-IN/plugin.json b/web/i18n/hi-IN/plugin.json index 5a16c2676e..8d3a1c2d1d 100644 --- a/web/i18n/hi-IN/plugin.json +++ b/web/i18n/hi-IN/plugin.json @@ -3,6 +3,7 @@ "action.delete": "प्लगइन हटाएं", "action.deleteContentLeft": "क्या 
आप हटाना चाहेंगे", "action.deleteContentRight": "प्लगइन?", + "action.deleteSuccess": "प्लगइन सफलतापूर्वक हटाया गया", "action.pluginInfo": "प्लगइन जानकारी", "action.usedInApps": "यह प्लगइन {{num}} ऐप्स में उपयोग किया जा रहा है।", "allCategories": "सभी श्रेणियाँ", @@ -114,6 +115,7 @@ "detailPanel.operation.install": "स्थापित करें", "detailPanel.operation.remove": "हटाएं", "detailPanel.operation.update": "अपडेट", + "detailPanel.operation.updateTooltip": "नवीनतम मॉडल तक पहुँचने के लिए अपडेट करें।", "detailPanel.operation.viewDetail": "विवरण देखें", "detailPanel.serviceOk": "सेवा ठीक है", "detailPanel.strategyNum": "{{num}} {{strategy}} शामिल", @@ -231,12 +233,18 @@ "source.local": "स्थानीय पैकेज फ़ाइल", "source.marketplace": "बाजार", "task.clearAll": "सभी साफ करें", + "task.errorMsg.github": "यह प्लगइन स्वचालित रूप से इंस्टॉल नहीं हो सका।\nकृपया इसे GitHub से इंस्टॉल करें।", + "task.errorMsg.marketplace": "यह प्लगइन स्वचालित रूप से इंस्टॉल नहीं हो सका।\nकृपया इसे Marketplace से इंस्टॉल करें।", + "task.errorMsg.unknown": "यह प्लगइन इंस्टॉल नहीं हो सका।\nप्लगइन स्रोत की पहचान नहीं हो सकी।", "task.errorPlugins": "Failed to Install Plugins", "task.installError": "{{errorLength}} प्लगइन्स स्थापित करने में विफल रहे, देखने के लिए क्लिक करें", + "task.installFromGithub": "GitHub से इंस्टॉल करें", + "task.installFromMarketplace": "Marketplace से इंस्टॉल करें", "task.installSuccess": "{{successLength}} plugins installed successfully", "task.installed": "Installed", "task.installedError": "{{errorLength}} प्लगइन्स स्थापित करने में विफल रहे", "task.installing": "प्लगइन्स स्थापित कर रहे हैं।", + "task.installingHint": "इंस्टॉल हो रहा है... 
इसमें कुछ मिनट लग सकते हैं।", "task.installingWithError": "{{installingLength}} प्लगइन्स स्थापित कर रहे हैं, {{successLength}} सफल, {{errorLength}} विफल", "task.installingWithSuccess": "{{installingLength}} प्लगइन्स स्थापित कर रहे हैं, {{successLength}} सफल।", "task.runningPlugins": "Installing Plugins", diff --git a/web/i18n/hi-IN/workflow.json b/web/i18n/hi-IN/workflow.json index 31a3d4627f..fb9256536e 100644 --- a/web/i18n/hi-IN/workflow.json +++ b/web/i18n/hi-IN/workflow.json @@ -305,6 +305,7 @@ "error.operations.updatingWorkflow": "वर्कफ़्लो को अपडेट करना", "error.startNodeRequired": "कृपया {{operation}} से पहले एक प्रारंभ नोड जोड़ें", "errorMsg.authRequired": "प्राधिकरण आवश्यक है", + "errorMsg.configureModel": "एक मॉडल कॉन्फ़िगर करें", "errorMsg.fieldRequired": "{{field}} आवश्यक है", "errorMsg.fields.code": "कोड", "errorMsg.fields.model": "मॉडल", @@ -314,6 +315,7 @@ "errorMsg.fields.visionVariable": "दृष्टि चर", "errorMsg.invalidJson": "{{field}} अमान्य JSON है", "errorMsg.invalidVariable": "अमान्य वेरिएबल", + "errorMsg.modelPluginNotInstalled": "अमान्य वेरिएबल। इस वेरिएबल को सक्षम करने के लिए एक मॉडल कॉन्फ़िगर करें।", "errorMsg.noValidTool": "{{field}} कोई मान्य उपकरण चयनित नहीं किया गया", "errorMsg.rerankModelRequired": "Rerank मॉडल चालू करने से पहले, कृपया पुष्टि करें कि मॉडल को सेटिंग्स में सफलतापूर्वक कॉन्फ़िगर किया गया है।", "errorMsg.startNodeRequired": "कृपया {{operation}} से पहले पहले एक स्टार्ट नोड जोड़ें", @@ -444,6 +446,7 @@ "nodes.common.memory.windowSize": "विंडो साइज", "nodes.common.outputVars": "आउटपुट वेरिएबल्स", "nodes.common.pluginNotInstalled": "प्लगइन इंस्टॉल नहीं है", + "nodes.common.pluginsNotInstalled": "{{count}} प्लगइन इंस्टॉल नहीं हैं", "nodes.common.retry.maxRetries": "अधिकतम पुनः प्रयास करता है", "nodes.common.retry.ms": "सुश्री", "nodes.common.retry.retries": "{{num}} पुनर्प्रयास", @@ -686,9 +689,14 @@ "nodes.knowledgeBase.chunksInput": "टुकड़े", "nodes.knowledgeBase.chunksInputTip": "ज्ञान आधार नोड का इनपुट वेरिएबल टुकड़े है। 
वेरिएबल प्रकार एक ऑब्जेक्ट है जिसमें एक विशेष JSON स्कीमा है जो चयनित चंक संरचना के साथ सुसंगत होना चाहिए।", "nodes.knowledgeBase.chunksVariableIsRequired": "टुकड़े चर आवश्यक है", + "nodes.knowledgeBase.embeddingModelApiKeyUnavailable": "API key अनुपलब्ध", + "nodes.knowledgeBase.embeddingModelCreditsExhausted": "क्रेडिट समाप्त", + "nodes.knowledgeBase.embeddingModelIncompatible": "असंगत", "nodes.knowledgeBase.embeddingModelIsInvalid": "एम्बेडिंग मॉडल अमान्य है", "nodes.knowledgeBase.embeddingModelIsRequired": "एम्बेडिंग मॉडल आवश्यक है", + "nodes.knowledgeBase.embeddingModelNotConfigured": "एम्बेडिंग मॉडल कॉन्फ़िगर नहीं है", "nodes.knowledgeBase.indexMethodIsRequired": "सूची विधि आवश्यक है", + "nodes.knowledgeBase.notConfigured": "कॉन्फ़िगर नहीं है", "nodes.knowledgeBase.rerankingModelIsInvalid": "पुनः क्रमांकन मॉडल अमान्य है", "nodes.knowledgeBase.rerankingModelIsRequired": "पुनः क्रमांकन मॉडल की आवश्यकता है", "nodes.knowledgeBase.retrievalSettingIsRequired": "पुनप्राप्ति सेटिंग आवश्यक है", @@ -872,6 +880,7 @@ "nodes.templateTransform.codeSupportTip": "केवल Jinja2 का समर्थन करता है", "nodes.templateTransform.inputVars": "इनपुट वेरिएबल्स", "nodes.templateTransform.outputVars.output": "रूपांतरित सामग्री", + "nodes.tool.authorizationRequired": "प्राधिकरण आवश्यक", "nodes.tool.authorize": "अधिकृत करें", "nodes.tool.inputVars": "इनपुट वेरिएबल्स", "nodes.tool.insertPlaceholder1": "टाइप करें या दबाएँ", @@ -1062,10 +1071,12 @@ "panel.change": "बदलें", "panel.changeBlock": "नोड बदलें", "panel.checklist": "चेकलिस्ट", + "panel.checklistDescription": "प्रकाशित करने से पहले निम्नलिखित समस्याएँ हल करें", "panel.checklistResolved": "सभी समस्याएं हल हो गई हैं", "panel.checklistTip": "प्रकाशित करने से पहले सुनिश्चित करें कि सभी समस्याएं हल हो गई हैं", "panel.createdBy": "द्वारा बनाया गया ", "panel.goTo": "जाओ", + "panel.goToFix": "ठीक करने जाएँ", "panel.helpLink": "सहायता", "panel.maximize": "कैनवास का अधिकतम लाभ उठाएँ", "panel.minimize": "पूर्ण स्क्रीन से बाहर निकलें", diff --git 
a/web/i18n/id-ID/app-debug.json b/web/i18n/id-ID/app-debug.json index 21651349ba..feb0f0a296 100644 --- a/web/i18n/id-ID/app-debug.json +++ b/web/i18n/id-ID/app-debug.json @@ -235,11 +235,16 @@ "inputs.run": "Jalankan", "inputs.title": "Debug & Pratinjau", "inputs.userInputField": "Bidang Input Pengguna", + "manageModels": "Kelola model", "modelConfig.modeType.chat": "Mengobrol", "modelConfig.modeType.completion": "Lengkap", "modelConfig.model": "Pola", "modelConfig.setTone": "Atur nada respons", "modelConfig.title": "Model dan Parameter", + "noModelProviderConfigured": "Belum ada penyedia model yang dikonfigurasi", + "noModelProviderConfiguredTip": "Instal atau konfigurasi penyedia model untuk memulai.", + "noModelSelected": "Belum ada model yang dipilih", + "noModelSelectedTip": "Konfigurasi model di atas untuk melanjutkan.", "noResult": "Output akan ditampilkan di sini.", "notSetAPIKey.description": "Kunci penyedia LLM belum diatur, dan perlu diatur sebelum debugging.", "notSetAPIKey.settingBtn": "Buka pengaturan", diff --git a/web/i18n/id-ID/common.json b/web/i18n/id-ID/common.json index 90531af712..51cd429992 100644 --- a/web/i18n/id-ID/common.json +++ b/web/i18n/id-ID/common.json @@ -340,11 +340,25 @@ "modelProvider.auth.unAuthorized": "Sah", "modelProvider.buyQuota": "Beli Kuota", "modelProvider.callTimes": "Waktu panggilan", + "modelProvider.card.aiCreditsInUse": "Kredit AI sedang digunakan", + "modelProvider.card.aiCreditsOption": "Kredit AI", + "modelProvider.card.apiKeyOption": "Kunci API", + "modelProvider.card.apiKeyRequired": "Kunci API diperlukan", + "modelProvider.card.apiKeyUnavailableFallback": "Kunci API tidak tersedia, kini menggunakan kredit AI", + "modelProvider.card.apiKeyUnavailableFallbackDescription": "Periksa konfigurasi kunci API Anda untuk beralih kembali", "modelProvider.card.buyQuota": "Beli Kuota", "modelProvider.card.callTimes": "Waktu panggilan", + "modelProvider.card.creditsExhaustedDescription": "Silakan tingkatkan paket Anda 
atau konfigurasikan kunci API", + "modelProvider.card.creditsExhaustedFallback": "Kredit AI habis, kini menggunakan kunci API", + "modelProvider.card.creditsExhaustedFallbackDescription": "Tingkatkan paket Anda untuk melanjutkan prioritas kredit AI.", + "modelProvider.card.creditsExhaustedMessage": "Kredit AI telah habis", "modelProvider.card.modelAPI": "Model {{modelName}} menggunakan Kunci API.", "modelProvider.card.modelNotSupported": "Model {{modelName}} tidak terpasang.", "modelProvider.card.modelSupported": "Model {{modelName}} menggunakan kuota ini.", + "modelProvider.card.noApiKeysDescription": "Tambahkan kunci API untuk mulai menggunakan kredensial model Anda sendiri.", + "modelProvider.card.noApiKeysFallback": "Tidak ada kunci API, menggunakan kredit AI sebagai gantinya", + "modelProvider.card.noApiKeysTitle": "Belum ada kunci API yang dikonfigurasi", + "modelProvider.card.noAvailableUsage": "Tidak ada penggunaan yang tersedia", "modelProvider.card.onTrial": "Sedang Diadili", "modelProvider.card.paid": "Dibayar", "modelProvider.card.priorityUse": "Penggunaan prioritas", @@ -353,6 +367,11 @@ "modelProvider.card.removeKey": "Menghapus Kunci API", "modelProvider.card.tip": "Kredit pesan mendukung model dari {{modelNames}}. Prioritas akan diberikan pada kuota yang dibayarkan. 
Kuota gratis akan digunakan setelah kuota yang dibayarkan habis.", "modelProvider.card.tokens": "Token", + "modelProvider.card.unavailable": "Tidak tersedia", + "modelProvider.card.upgradePlan": "tingkatkan paket Anda", + "modelProvider.card.usageLabel": "Penggunaan", + "modelProvider.card.usagePriority": "Prioritas Penggunaan", + "modelProvider.card.usagePriorityTip": "Tentukan sumber daya mana yang digunakan terlebih dahulu saat menjalankan model.", "modelProvider.collapse": "Roboh", "modelProvider.config": "Konfigurasi", "modelProvider.configLoadBalancing": "Penyeimbangan Beban Konfigurasi", @@ -387,9 +406,11 @@ "modelProvider.model": "Pola", "modelProvider.modelAndParameters": "Model dan Parameter", "modelProvider.modelHasBeenDeprecated": "Model ini tidak digunakan lagi", + "modelProvider.modelSettings": "Pengaturan Model", "modelProvider.models": "Model", "modelProvider.modelsNum": "Model {{num}}", "modelProvider.noModelFound": "Tidak ditemukan model untuk {{model}}", + "modelProvider.noneConfigured": "Konfigurasikan model sistem default untuk menjalankan aplikasi", "modelProvider.notConfigured": "Model sistem belum sepenuhnya dikonfigurasi", "modelProvider.parameters": "PARAMETER", "modelProvider.parametersInvalidRemoved": "Beberapa parameter tidak valid dan telah dihapus", @@ -403,8 +424,25 @@ "modelProvider.resetDate": "Setel ulang pada {{date}}", "modelProvider.searchModel": "Model pencarian", "modelProvider.selectModel": "Pilih model Anda", + "modelProvider.selector.aiCredits": "Kredit AI", + "modelProvider.selector.apiKeyUnavailable": "Kunci API tidak tersedia", + "modelProvider.selector.apiKeyUnavailableTip": "Kunci API telah dihapus. Silakan konfigurasikan kunci API baru.", + "modelProvider.selector.configure": "Konfigurasikan", + "modelProvider.selector.configureRequired": "Konfigurasi diperlukan", + "modelProvider.selector.creditsExhausted": "Kredit habis", + "modelProvider.selector.creditsExhaustedTip": "Kredit AI Anda telah habis. 
Silakan tingkatkan paket Anda atau tambahkan kunci API.", + "modelProvider.selector.disabled": "Dinonaktifkan", + "modelProvider.selector.discoverMoreInMarketplace": "Temukan lebih banyak di Marketplace", "modelProvider.selector.emptySetting": "Silakan buka pengaturan untuk mengonfigurasi", "modelProvider.selector.emptyTip": "Tidak ada model yang tersedia", + "modelProvider.selector.fromMarketplace": "Dari Marketplace", + "modelProvider.selector.incompatible": "Tidak kompatibel", + "modelProvider.selector.incompatibleTip": "Model ini tidak tersedia dalam versi saat ini. Silakan pilih model lain yang tersedia.", + "modelProvider.selector.install": "Instal", + "modelProvider.selector.modelProviderSettings": "Pengaturan Penyedia Model", + "modelProvider.selector.noProviderConfigured": "Belum ada penyedia model yang dikonfigurasi", + "modelProvider.selector.noProviderConfiguredDesc": "Jelajahi Marketplace untuk menginstal, atau konfigurasikan penyedia di pengaturan.", + "modelProvider.selector.onlyCompatibleModelsShown": "Hanya model yang kompatibel yang ditampilkan", "modelProvider.selector.rerankTip": "Silakan atur model Rerank", "modelProvider.selector.tip": "Model ini telah dihapus. 
Silakan tambahkan model atau pilih model lain.", "modelProvider.setupModelFirst": "Silakan atur model Anda terlebih dahulu", diff --git a/web/i18n/id-ID/plugin.json b/web/i18n/id-ID/plugin.json index 486de762f8..6b9c8f51af 100644 --- a/web/i18n/id-ID/plugin.json +++ b/web/i18n/id-ID/plugin.json @@ -3,6 +3,7 @@ "action.delete": "Hapus plugin", "action.deleteContentLeft": "Apakah Anda ingin menghapus", "action.deleteContentRight": "Plugin?", + "action.deleteSuccess": "Plugin berhasil dihapus", "action.pluginInfo": "Info plugin", "action.usedInApps": "Plugin ini digunakan di {{num}} aplikasi.", "allCategories": "Semua Kategori", @@ -114,6 +115,7 @@ "detailPanel.operation.install": "Pasang", "detailPanel.operation.remove": "Hapus", "detailPanel.operation.update": "Pemutakhiran", + "detailPanel.operation.updateTooltip": "Perbarui untuk mengakses model terbaru.", "detailPanel.operation.viewDetail": "Lihat Detail", "detailPanel.serviceOk": "Layanan OK", "detailPanel.strategyNum": "{{num}} {{strategy}} TERMASUK", @@ -231,12 +233,18 @@ "source.local": "File Paket Lokal", "source.marketplace": "Pasar", "task.clearAll": "Hapus semua", + "task.errorMsg.github": "Plugin ini tidak terinstal secara otomatis.\nSilakan instal dari GitHub.", + "task.errorMsg.marketplace": "Plugin ini tidak terinstal secara otomatis.\nSilakan instal dari Marketplace.", + "task.errorMsg.unknown": "Plugin ini tidak terinstal.\nSumber plugin tidak dapat diidentifikasi.", "task.errorPlugins": "Failed to Install Plugins", "task.installError": "Gagal menginstal plugin {{errorLength}}, klik untuk melihat", + "task.installFromGithub": "Instal dari GitHub", + "task.installFromMarketplace": "Instal dari Marketplace", "task.installSuccess": "{{successLength}} plugins installed successfully", "task.installed": "Installed", "task.installedError": "Gagal menginstal {{errorLength}} plugin", "task.installing": "Memasang plugin.", + "task.installingHint": "Menginstal... 
Ini mungkin memerlukan beberapa menit.", "task.installingWithError": "Memasang {{installingLength}} plugin, {{successLength}} berhasil, {{errorLength}} gagal", "task.installingWithSuccess": "Memasang plugin {{installingLength}}, {{successLength}} berhasil.", "task.runningPlugins": "Installing Plugins", diff --git a/web/i18n/id-ID/workflow.json b/web/i18n/id-ID/workflow.json index 4f8e01b7f1..76e80be7d7 100644 --- a/web/i18n/id-ID/workflow.json +++ b/web/i18n/id-ID/workflow.json @@ -305,6 +305,7 @@ "error.operations.updatingWorkflow": "memperbarui alur kerja", "error.startNodeRequired": "Silakan tambahkan node awal terlebih dahulu sebelum {{operation}}", "errorMsg.authRequired": "Otorisasi diperlukan", + "errorMsg.configureModel": "Konfigurasikan model", "errorMsg.fieldRequired": "{{field}} wajib diisi", "errorMsg.fields.code": "Kode", "errorMsg.fields.model": "Model", @@ -314,6 +315,7 @@ "errorMsg.fields.visionVariable": "Variabel Penglihatan", "errorMsg.invalidJson": "{{field}} adalah JSON yang tidak valid", "errorMsg.invalidVariable": "Variabel tidak valid", + "errorMsg.modelPluginNotInstalled": "Variabel tidak valid. 
Konfigurasikan model untuk mengaktifkan variabel ini.", "errorMsg.noValidTool": "{{field}} tidak ada alat yang valid dipilih", "errorMsg.rerankModelRequired": "Model Rerank yang dikonfigurasi diperlukan", "errorMsg.startNodeRequired": "Silakan tambahkan node awal terlebih dahulu sebelum {{operation}}", @@ -444,6 +446,7 @@ "nodes.common.memory.windowSize": "Ukuran Jendela", "nodes.common.outputVars": "Variabel Keluaran", "nodes.common.pluginNotInstalled": "Plugin tidak terpasang", + "nodes.common.pluginsNotInstalled": "{{count}} plugin tidak terpasang", "nodes.common.retry.maxRetries": "percobaan ulang maks", "nodes.common.retry.ms": "Ms", "nodes.common.retry.retries": "{{num}} Percobaan Ulang", @@ -686,9 +689,14 @@ "nodes.knowledgeBase.chunksInput": "Potongan", "nodes.knowledgeBase.chunksInputTip": "Variabel input dari node basis pengetahuan adalah Chunks. Tipe variabel adalah objek dengan Skema JSON tertentu yang harus konsisten dengan struktur chunk yang dipilih.", "nodes.knowledgeBase.chunksVariableIsRequired": "Variabel Chunks diperlukan", + "nodes.knowledgeBase.embeddingModelApiKeyUnavailable": "Kunci API tidak tersedia", + "nodes.knowledgeBase.embeddingModelCreditsExhausted": "Kredit habis", + "nodes.knowledgeBase.embeddingModelIncompatible": "Tidak kompatibel", "nodes.knowledgeBase.embeddingModelIsInvalid": "Model embedding tidak valid", "nodes.knowledgeBase.embeddingModelIsRequired": "Model embedding diperlukan", + "nodes.knowledgeBase.embeddingModelNotConfigured": "Model embedding belum dikonfigurasi", "nodes.knowledgeBase.indexMethodIsRequired": "Metode indeks diperlukan", + "nodes.knowledgeBase.notConfigured": "Belum dikonfigurasi", "nodes.knowledgeBase.rerankingModelIsInvalid": "Model reranking tidak valid", "nodes.knowledgeBase.rerankingModelIsRequired": "Model reranking diperlukan", "nodes.knowledgeBase.retrievalSettingIsRequired": "Pengaturan pengambilan diperlukan", @@ -872,6 +880,7 @@ "nodes.templateTransform.codeSupportTip": "Hanya mendukung 
Jinja2", "nodes.templateTransform.inputVars": "Variabel Masukan", "nodes.templateTransform.outputVars.output": "Konten yang diubah", + "nodes.tool.authorizationRequired": "Otorisasi diperlukan", "nodes.tool.authorize": "Otorisasi", "nodes.tool.inputVars": "Variabel Masukan", "nodes.tool.insertPlaceholder1": "Ketik atau tekan", @@ -1062,10 +1071,12 @@ "panel.change": "Ubah", "panel.changeBlock": "Ubah Node", "panel.checklist": "Checklist", + "panel.checklistDescription": "Selesaikan masalah berikut sebelum menerbitkan", "panel.checklistResolved": "Semua masalah terselesaikan", "panel.checklistTip": "Pastikan semua masalah diselesaikan sebelum dipublikasikan", "panel.createdBy": "Dibuat oleh", "panel.goTo": "Pergi ke", + "panel.goToFix": "Pergi ke perbaikan", "panel.helpLink": "Docs", "panel.maximize": "Maksimalkan Kanvas", "panel.minimize": "Keluar dari Layar Penuh", diff --git a/web/i18n/it-IT/app-debug.json b/web/i18n/it-IT/app-debug.json index 94e4c00894..b9aad1ff6b 100644 --- a/web/i18n/it-IT/app-debug.json +++ b/web/i18n/it-IT/app-debug.json @@ -235,11 +235,16 @@ "inputs.run": "ESEGUI", "inputs.title": "Debug e Anteprima", "inputs.userInputField": "Campo Input Utente", + "manageModels": "Gestisci modelli", "modelConfig.modeType.chat": "Chat", "modelConfig.modeType.completion": "Completamento", "modelConfig.model": "Modello", "modelConfig.setTone": "Imposta tono delle risposte", "modelConfig.title": "Modello e Parametri", + "noModelProviderConfigured": "Nessun fornitore di modelli configurato", + "noModelProviderConfiguredTip": "Installa o configura un fornitore di modelli per iniziare.", + "noModelSelected": "Nessun modello selezionato", + "noModelSelectedTip": "Configura un modello sopra per continuare.", "noResult": "L'output verrà visualizzato qui.", "notSetAPIKey.description": "La chiave del provider LLM non è stata impostata e deve essere impostata prima del debug.", "notSetAPIKey.settingBtn": "Vai alle impostazioni", diff --git 
a/web/i18n/it-IT/common.json b/web/i18n/it-IT/common.json index d7b57d6d08..283c090ea8 100644 --- a/web/i18n/it-IT/common.json +++ b/web/i18n/it-IT/common.json @@ -340,11 +340,25 @@ "modelProvider.auth.unAuthorized": "Non autorizzato", "modelProvider.buyQuota": "Acquista Quota", "modelProvider.callTimes": "Numero di chiamate", + "modelProvider.card.aiCreditsInUse": "Crediti AI in uso", + "modelProvider.card.aiCreditsOption": "Crediti AI", + "modelProvider.card.apiKeyOption": "API Key", + "modelProvider.card.apiKeyRequired": "API Key richiesta", + "modelProvider.card.apiKeyUnavailableFallback": "API Key non disponibile, utilizzo dei crediti AI in corso", + "modelProvider.card.apiKeyUnavailableFallbackDescription": "Controlla la configurazione della tua API Key per tornare indietro", "modelProvider.card.buyQuota": "Acquista Quota", "modelProvider.card.callTimes": "Numero di chiamate", + "modelProvider.card.creditsExhaustedDescription": "Aggiorna il tuo piano o configura una API Key", + "modelProvider.card.creditsExhaustedFallback": "Crediti AI esauriti, utilizzo della API Key in corso", + "modelProvider.card.creditsExhaustedFallbackDescription": "Aggiorna il tuo piano per ripristinare la priorità dei crediti AI.", + "modelProvider.card.creditsExhaustedMessage": "I crediti AI sono esauriti", "modelProvider.card.modelAPI": "I modelli {{modelName}} stanno utilizzando la chiave API.", "modelProvider.card.modelNotSupported": "I modelli {{modelName}} non sono installati.", "modelProvider.card.modelSupported": "I modelli {{modelName}} stanno utilizzando questa quota.", + "modelProvider.card.noApiKeysDescription": "Aggiungi una API Key per iniziare a usare le tue credenziali del modello.", + "modelProvider.card.noApiKeysFallback": "Nessuna API Key, utilizzo dei crediti AI in corso", + "modelProvider.card.noApiKeysTitle": "Nessuna API Key ancora configurata", + "modelProvider.card.noAvailableUsage": "Nessun utilizzo disponibile", "modelProvider.card.onTrial": "In Prova", 
"modelProvider.card.paid": "Pagato", "modelProvider.card.priorityUse": "Uso prioritario", @@ -353,6 +367,11 @@ "modelProvider.card.removeKey": "Rimuovi API Key", "modelProvider.card.tip": "I crediti di messaggi supportano modelli di {{modelNames}}. Verrà data priorità alla quota pagata. La quota gratuita sarà utilizzata dopo l'esaurimento della quota pagata.", "modelProvider.card.tokens": "Token", + "modelProvider.card.unavailable": "Non disponibile", + "modelProvider.card.upgradePlan": "aggiorna il tuo piano", + "modelProvider.card.usageLabel": "Utilizzo", + "modelProvider.card.usagePriority": "Priorità di utilizzo", + "modelProvider.card.usagePriorityTip": "Imposta quale risorsa utilizzare per prima durante l'esecuzione dei modelli.", "modelProvider.collapse": "Comprimi", "modelProvider.config": "Configura", "modelProvider.configLoadBalancing": "Configura Bilanciamento del Carico", @@ -387,9 +406,11 @@ "modelProvider.model": "Modello", "modelProvider.modelAndParameters": "Modello e Parametri", "modelProvider.modelHasBeenDeprecated": "Questo modello è stato deprecato", + "modelProvider.modelSettings": "Impostazioni modello", "modelProvider.models": "Modelli", "modelProvider.modelsNum": "{{num}} Modelli", "modelProvider.noModelFound": "Nessun modello trovato per {{model}}", + "modelProvider.noneConfigured": "Configura un modello di sistema predefinito per eseguire le applicazioni", "modelProvider.notConfigured": "Il modello di sistema non è ancora stato completamente configurato e alcune funzioni potrebbero non essere disponibili.", "modelProvider.parameters": "PARAMETRI", "modelProvider.parametersInvalidRemoved": "Alcuni parametri non sono validi e sono stati rimossi.", @@ -403,8 +424,25 @@ "modelProvider.resetDate": "Ripristina il {{date}}", "modelProvider.searchModel": "Modello di ricerca", "modelProvider.selectModel": "Seleziona il tuo modello", + "modelProvider.selector.aiCredits": "Crediti AI", + "modelProvider.selector.apiKeyUnavailable": "API Key non 
disponibile", + "modelProvider.selector.apiKeyUnavailableTip": "La API Key è stata rimossa. Configura una nuova API Key.", + "modelProvider.selector.configure": "Configura", + "modelProvider.selector.configureRequired": "Configurazione richiesta", + "modelProvider.selector.creditsExhausted": "Crediti esauriti", + "modelProvider.selector.creditsExhaustedTip": "I tuoi crediti AI sono esauriti. Aggiorna il tuo piano o aggiungi una API Key.", + "modelProvider.selector.disabled": "Disabilitato", + "modelProvider.selector.discoverMoreInMarketplace": "Scopri di più nel Marketplace", "modelProvider.selector.emptySetting": "Per favore vai alle impostazioni per configurare", "modelProvider.selector.emptyTip": "Nessun modello disponibile", + "modelProvider.selector.fromMarketplace": "Dal Marketplace", + "modelProvider.selector.incompatible": "Incompatibile", + "modelProvider.selector.incompatibleTip": "Questo modello non è disponibile nella versione corrente. Seleziona un altro modello disponibile.", + "modelProvider.selector.install": "Installa", + "modelProvider.selector.modelProviderSettings": "Impostazioni fornitore modelli", + "modelProvider.selector.noProviderConfigured": "Nessun fornitore di modelli configurato", + "modelProvider.selector.noProviderConfiguredDesc": "Cerca nel Marketplace per installarne uno o configura i fornitori nelle impostazioni.", + "modelProvider.selector.onlyCompatibleModelsShown": "Vengono mostrati solo i modelli compatibili", "modelProvider.selector.rerankTip": "Per favore, configura il modello di Rerank", "modelProvider.selector.tip": "Questo modello è stato rimosso. 
Per favore aggiungi un modello o seleziona un altro modello.", "modelProvider.setupModelFirst": "Per favore, configura prima il tuo modello", diff --git a/web/i18n/it-IT/plugin.json b/web/i18n/it-IT/plugin.json index 06a2dfa4ab..296aa31d54 100644 --- a/web/i18n/it-IT/plugin.json +++ b/web/i18n/it-IT/plugin.json @@ -3,6 +3,7 @@ "action.delete": "Rimuovi plugin", "action.deleteContentLeft": "Vorresti rimuovere", "action.deleteContentRight": "plugin?", + "action.deleteSuccess": "Plugin rimosso con successo", "action.pluginInfo": "Informazioni sul plugin", "action.usedInApps": "Questo plugin viene utilizzato nelle app {{num}}.", "allCategories": "Tutte le categorie", @@ -114,6 +115,7 @@ "detailPanel.operation.install": "Installare", "detailPanel.operation.remove": "Togliere", "detailPanel.operation.update": "Aggiornare", + "detailPanel.operation.updateTooltip": "Aggiorna per accedere ai modelli più recenti.", "detailPanel.operation.viewDetail": "vedi dettagli", "detailPanel.serviceOk": "Servizio OK", "detailPanel.strategyNum": "{{num}} {{strategy}} INCLUSO", @@ -231,12 +233,18 @@ "source.local": "File del pacchetto locale", "source.marketplace": "Mercato", "task.clearAll": "Cancella tutto", + "task.errorMsg.github": "Impossibile installare automaticamente questo plugin.\nInstallalo da GitHub.", + "task.errorMsg.marketplace": "Impossibile installare automaticamente questo plugin.\nInstallalo dal Marketplace.", + "task.errorMsg.unknown": "Impossibile installare questo plugin.\nNon è stato possibile identificare la fonte del plugin.", "task.errorPlugins": "Failed to Install Plugins", "task.installError": "Impossibile installare i plugin {{errorLength}}, clicca per visualizzare", + "task.installFromGithub": "Installa da GitHub", + "task.installFromMarketplace": "Installa dal Marketplace", "task.installSuccess": "{{successLength}} plugins installed successfully", "task.installed": "Installed", "task.installedError": "Impossibile installare i plugin di {{errorLength}}", 
"task.installing": "Installazione dei plugin.", + "task.installingHint": "Installazione in corso... Potrebbe richiedere qualche minuto.", "task.installingWithError": "Installazione dei plugin {{installingLength}}, {{successLength}} successo, {{errorLength}} fallito", "task.installingWithSuccess": "Installazione dei plugin {{installingLength}}, {{successLength}} successo.", "task.runningPlugins": "Installing Plugins", diff --git a/web/i18n/it-IT/workflow.json b/web/i18n/it-IT/workflow.json index 0184b85009..519d7d7e2a 100644 --- a/web/i18n/it-IT/workflow.json +++ b/web/i18n/it-IT/workflow.json @@ -305,6 +305,7 @@ "error.operations.updatingWorkflow": "aggiornamento del flusso di lavoro", "error.startNodeRequired": "Per favore aggiungi prima un nodo iniziale prima di {{operation}}", "errorMsg.authRequired": "È richiesta l'autorizzazione", + "errorMsg.configureModel": "Configura un modello", "errorMsg.fieldRequired": "{{field}} è richiesto", "errorMsg.fields.code": "Codice", "errorMsg.fields.model": "Modello", @@ -314,6 +315,7 @@ "errorMsg.fields.visionVariable": "Visione variabile", "errorMsg.invalidJson": "{{field}} è un JSON non valido", "errorMsg.invalidVariable": "Variabile non valida", + "errorMsg.modelPluginNotInstalled": "Variabile non valida. 
Configura un modello per abilitare questa variabile.", "errorMsg.noValidTool": "{{field}} nessuno strumento valido selezionato", "errorMsg.rerankModelRequired": "Prima di attivare il modello di reranking, conferma che il modello è stato configurato correttamente nelle impostazioni.", "errorMsg.startNodeRequired": "Per favore aggiungi prima un nodo iniziale prima di {{operation}}", @@ -444,6 +446,7 @@ "nodes.common.memory.windowSize": "Dimensione Finestra", "nodes.common.outputVars": "Variabili di Output", "nodes.common.pluginNotInstalled": "Il plugin non è installato", + "nodes.common.pluginsNotInstalled": "{{count}} plugin non installati", "nodes.common.retry.maxRetries": "Numero massimo di tentativi", "nodes.common.retry.ms": "ms", "nodes.common.retry.retries": "{{num}} Tentativi", @@ -686,9 +689,14 @@ "nodes.knowledgeBase.chunksInput": "Pezzetti", "nodes.knowledgeBase.chunksInputTip": "La variabile di input del nodo della base di conoscenza è Chunks. Il tipo di variabile è un oggetto con uno specifico schema JSON che deve essere coerente con la struttura del chunk selezionato.", "nodes.knowledgeBase.chunksVariableIsRequired": "La variabile Chunks è richiesta", + "nodes.knowledgeBase.embeddingModelApiKeyUnavailable": "API Key non disponibile", + "nodes.knowledgeBase.embeddingModelCreditsExhausted": "Crediti esauriti", + "nodes.knowledgeBase.embeddingModelIncompatible": "Incompatibile", "nodes.knowledgeBase.embeddingModelIsInvalid": "Il modello di embedding non è valido", "nodes.knowledgeBase.embeddingModelIsRequired": "È necessario un modello di embedding", + "nodes.knowledgeBase.embeddingModelNotConfigured": "Modello di embedding non configurato", "nodes.knowledgeBase.indexMethodIsRequired": "È necessario il metodo dell'indice", + "nodes.knowledgeBase.notConfigured": "Non configurato", "nodes.knowledgeBase.rerankingModelIsInvalid": "Il modello di riorganizzazione è non valido", "nodes.knowledgeBase.rerankingModelIsRequired": "È richiesto un modello di 
riordinamento", "nodes.knowledgeBase.retrievalSettingIsRequired": "È richiesta l'impostazione di recupero", @@ -872,6 +880,7 @@ "nodes.templateTransform.codeSupportTip": "Supporta solo Jinja2", "nodes.templateTransform.inputVars": "Variabili di Input", "nodes.templateTransform.outputVars.output": "Contenuto trasformato", + "nodes.tool.authorizationRequired": "Autorizzazione richiesta", "nodes.tool.authorize": "Autorizza", "nodes.tool.inputVars": "Variabili di Input", "nodes.tool.insertPlaceholder1": "Digita o premi", @@ -1062,10 +1071,12 @@ "panel.change": "Cambia", "panel.changeBlock": "Cambia Nodo", "panel.checklist": "Checklist", + "panel.checklistDescription": "Risolvi i seguenti problemi prima della pubblicazione", "panel.checklistResolved": "Tutti i problemi sono risolti", "panel.checklistTip": "Assicurati che tutti i problemi siano risolti prima di pubblicare", "panel.createdBy": "Creato da ", "panel.goTo": "Vai a", + "panel.goToFix": "Vai a correggere", "panel.helpLink": "Aiuto", "panel.maximize": "Massimizza Canvas", "panel.minimize": "Esci dalla modalità schermo intero", diff --git a/web/i18n/ja-JP/common.json b/web/i18n/ja-JP/common.json index 98fe556ecf..a65d8e933c 100644 --- a/web/i18n/ja-JP/common.json +++ b/web/i18n/ja-JP/common.json @@ -358,6 +358,7 @@ "modelProvider.card.noApiKeysDescription": "独自のモデル認証情報を使用するには、API キーを追加してください。", "modelProvider.card.noApiKeysFallback": "API キーが未設定のため、AI クレジットを使用しています", "modelProvider.card.noApiKeysTitle": "API キーはまだ設定されていません", + "modelProvider.card.noAvailableUsage": "利用可能な使用量がありません", "modelProvider.card.onTrial": "トライアル中", "modelProvider.card.paid": "有料", "modelProvider.card.priorityUse": "優先利用", @@ -366,6 +367,9 @@ "modelProvider.card.removeKey": "API キーを削除", "modelProvider.card.tip": "メッセージ枠は{{modelNames}}のモデルを使用することをサポートしています。無料枠は有料枠が使い果たされた後に消費されます。", "modelProvider.card.tokens": "トークン", + "modelProvider.card.unavailable": "利用不可", + "modelProvider.card.upgradePlan": "プランをアップグレード", + 
"modelProvider.card.usageLabel": "使用量", "modelProvider.card.usagePriority": "使用優先順位", "modelProvider.card.usagePriorityTip": "モデル実行時に優先して使用するリソースを設定します。", "modelProvider.collapse": "折り畳み", @@ -402,9 +406,11 @@ "modelProvider.model": "モデル", "modelProvider.modelAndParameters": "モデルとパラメータ", "modelProvider.modelHasBeenDeprecated": "このモデルは廃止予定です", + "modelProvider.modelSettings": "モデル設定", "modelProvider.models": "モデル", "modelProvider.modelsNum": "{{num}}のモデル", "modelProvider.noModelFound": "{{model}}に対するモデルが見つかりません", + "modelProvider.noneConfigured": "アプリケーションを実行するためにデフォルトのシステムモデルを設定してください", "modelProvider.notConfigured": "システムモデルがまだ完全に設定されておらず、一部の機能が利用できない場合があります。", "modelProvider.parameters": "パラメータ", "modelProvider.parametersInvalidRemoved": "いくつかのパラメータが無効であり、削除されました。", @@ -421,14 +427,21 @@ "modelProvider.selector.aiCredits": "AI クレジット", "modelProvider.selector.apiKeyUnavailable": "API キーが利用できません", "modelProvider.selector.apiKeyUnavailableTip": "API キーは削除されました。新しい API キーを設定してください。", + "modelProvider.selector.configure": "設定", "modelProvider.selector.configureRequired": "設定が必要です", "modelProvider.selector.creditsExhausted": "クレジットを使い切りました", "modelProvider.selector.creditsExhaustedTip": "AI クレジットを使い切りました。プランをアップグレードするか、API キーを追加してください。", + "modelProvider.selector.disabled": "無効", + "modelProvider.selector.discoverMoreInMarketplace": "マーケットプレイスでもっと探す", "modelProvider.selector.emptySetting": "設定に移動して構成してください", "modelProvider.selector.emptyTip": "利用可能なモデルはありません", "modelProvider.selector.fromMarketplace": "マーケットプレイスから", "modelProvider.selector.incompatible": "非互換", "modelProvider.selector.incompatibleTip": "このモデルは現在のバージョンでは利用できません。別の利用可能なモデルを選択してください。", + "modelProvider.selector.install": "インストール", + "modelProvider.selector.modelProviderSettings": "モデルプロバイダー設定", + "modelProvider.selector.noProviderConfigured": "モデルプロバイダーが設定されていません", + "modelProvider.selector.noProviderConfiguredDesc": "マーケットプレイスでインストールするか、設定でプロバイダーを設定してください。", 
"modelProvider.selector.onlyCompatibleModelsShown": "互換性のあるモデルのみが表示されます", "modelProvider.selector.rerankTip": "Rerank モデルを設定してください", "modelProvider.selector.tip": "このモデルは削除されました。別のモデルを追加するか、別のモデルを選択してください。", diff --git a/web/i18n/ja-JP/plugin.json b/web/i18n/ja-JP/plugin.json index c51a0e4117..38645124fb 100644 --- a/web/i18n/ja-JP/plugin.json +++ b/web/i18n/ja-JP/plugin.json @@ -3,6 +3,7 @@ "action.delete": "プラグインを削除する", "action.deleteContentLeft": "削除しますか", "action.deleteContentRight": "プラグイン?", + "action.deleteSuccess": "プラグインが正常に削除されました", "action.pluginInfo": "プラグイン情報", "action.usedInApps": "このプラグインは{{num}}のアプリで使用されています。", "allCategories": "すべてのカテゴリ", @@ -114,6 +115,7 @@ "detailPanel.operation.install": "インストール", "detailPanel.operation.remove": "削除", "detailPanel.operation.update": "更新", + "detailPanel.operation.updateTooltip": "最新のモデルにアクセスするために更新してください。", "detailPanel.operation.viewDetail": "詳細を見る", "detailPanel.serviceOk": "サービスは正常です", "detailPanel.strategyNum": "{{num}} {{strategy}} が含まれています", @@ -231,12 +233,18 @@ "source.local": "ローカルパッケージファイル", "source.marketplace": "マーケットプレイス", "task.clearAll": "すべてクリア", + "task.errorMsg.github": "このプラグインは自動インストールできませんでした。\nGitHub からインストールしてください。", + "task.errorMsg.marketplace": "このプラグインは自動インストールできませんでした。\nマーケットプレイスからインストールしてください。", + "task.errorMsg.unknown": "このプラグインをインストールできませんでした。\nプラグインのソースを特定できませんでした。", "task.errorPlugins": "Failed to Install Plugins", "task.installError": "{{errorLength}} プラグインのインストールに失敗しました。表示するにはクリックしてください。", + "task.installFromGithub": "GitHub からインストール", + "task.installFromMarketplace": "マーケットプレイスからインストール", "task.installSuccess": "{{successLength}} plugins installed successfully", "task.installed": "Installed", "task.installedError": "{{errorLength}} プラグインのインストールに失敗しました", "task.installing": "プラグインをインストール中。", + "task.installingHint": "インストール中...数分かかる場合があります。", "task.installingWithError": "{{installingLength}}個のプラグインをインストール中、{{successLength}}件成功、{{errorLength}}件失敗", "task.installingWithSuccess": 
"{{installingLength}}個のプラグインをインストール中、{{successLength}}個成功しました。", "task.runningPlugins": "Installing Plugins", diff --git a/web/i18n/ja-JP/workflow.json b/web/i18n/ja-JP/workflow.json index a2fac68bad..b9b60d6e73 100644 --- a/web/i18n/ja-JP/workflow.json +++ b/web/i18n/ja-JP/workflow.json @@ -305,6 +305,7 @@ "error.operations.updatingWorkflow": "ワークフロー更新", "error.startNodeRequired": "{{operation}}前に開始ノードを追加してください", "errorMsg.authRequired": "認証が必要です", + "errorMsg.configureModel": "モデルを設定してください", "errorMsg.fieldRequired": "{{field}} は必須です", "errorMsg.fields.code": "コード", "errorMsg.fields.model": "モデル", @@ -314,6 +315,7 @@ "errorMsg.fields.visionVariable": "ビジョン変数", "errorMsg.invalidJson": "{{field}} は無効な JSON です", "errorMsg.invalidVariable": "無効な変数です", + "errorMsg.modelPluginNotInstalled": "無効な変数です。この変数を有効にするにはモデルを設定してください。", "errorMsg.noValidTool": "{{field}} に利用可能なツールがありません", "errorMsg.rerankModelRequired": "Rerank モデルが設定されていません", "errorMsg.startNodeRequired": "{{operation}}前に開始ノードを追加してください", @@ -444,6 +446,7 @@ "nodes.common.memory.windowSize": "メモリウィンドウサイズ", "nodes.common.outputVars": "出力変数", "nodes.common.pluginNotInstalled": "プラグインがインストールされていません", + "nodes.common.pluginsNotInstalled": "{{count}} 個のプラグインがインストールされていません", "nodes.common.retry.maxRetries": "最大試行回数", "nodes.common.retry.ms": "ミリ秒", "nodes.common.retry.retries": "再試行回数:{{num}}", @@ -686,9 +689,14 @@ "nodes.knowledgeBase.chunksInput": "チャンク", "nodes.knowledgeBase.chunksInputTip": "知識ベースノードの入力変数はチャンクです。変数のタイプは、選択されたチャンク構造と一貫性のある特定のJSONスキーマを持つオブジェクトです。", "nodes.knowledgeBase.chunksVariableIsRequired": "Chunks変数は必須です", + "nodes.knowledgeBase.embeddingModelApiKeyUnavailable": "API キーが利用できません", + "nodes.knowledgeBase.embeddingModelCreditsExhausted": "クレジットを使い切りました", + "nodes.knowledgeBase.embeddingModelIncompatible": "非互換", "nodes.knowledgeBase.embeddingModelIsInvalid": "埋め込みモデルが無効です", "nodes.knowledgeBase.embeddingModelIsRequired": "埋め込みモデルが必要です", + "nodes.knowledgeBase.embeddingModelNotConfigured": 
"埋め込みモデルが設定されていません", "nodes.knowledgeBase.indexMethodIsRequired": "インデックスメソッドが必要です", + "nodes.knowledgeBase.notConfigured": "未設定", "nodes.knowledgeBase.rerankingModelIsInvalid": "リランキングモデルは無効です", "nodes.knowledgeBase.rerankingModelIsRequired": "再ランキングモデルが必要です", "nodes.knowledgeBase.retrievalSettingIsRequired": "リトリーバル設定が必要です", @@ -1063,10 +1071,12 @@ "panel.change": "変更", "panel.changeBlock": "ノード変更", "panel.checklist": "チェックリスト", + "panel.checklistDescription": "公開前に以下の問題を解決してください", "panel.checklistResolved": "全てのチェックが完了しました", "panel.checklistTip": "公開前に全ての項目を確認してください", "panel.createdBy": "作成者", "panel.goTo": "移動", + "panel.goToFix": "修正する", "panel.helpLink": "ドキュメントを見る", "panel.maximize": "キャンバスを最大化する", "panel.minimize": "全画面を終了する", diff --git a/web/i18n/ko-KR/app-debug.json b/web/i18n/ko-KR/app-debug.json index 3cb295a3f7..1f9c1e2420 100644 --- a/web/i18n/ko-KR/app-debug.json +++ b/web/i18n/ko-KR/app-debug.json @@ -235,11 +235,16 @@ "inputs.run": "실행", "inputs.title": "디버그 및 미리보기", "inputs.userInputField": "사용자 입력 필드", + "manageModels": "모델 관리", "modelConfig.modeType.chat": "채팅", "modelConfig.modeType.completion": "완료", "modelConfig.model": "모델", "modelConfig.setTone": "응답 톤 설정", "modelConfig.title": "모델 및 매개변수", + "noModelProviderConfigured": "모델 공급자가 구성되지 않았습니다", + "noModelProviderConfiguredTip": "모델 공급자를 설치하거나 구성하여 시작하세요.", + "noModelSelected": "모델이 선택되지 않았습니다", + "noModelSelectedTip": "계속하려면 위에서 모델을 구성하세요.", "noResult": "출력이 여기에 표시됩니다.", "notSetAPIKey.description": "LLM 제공자 키가 설정되지 않았습니다. 
디버깅하기 전에 설정해야 합니다.", "notSetAPIKey.settingBtn": "설정으로 이동", diff --git a/web/i18n/ko-KR/common.json b/web/i18n/ko-KR/common.json index 8fa351c517..e50a7c2428 100644 --- a/web/i18n/ko-KR/common.json +++ b/web/i18n/ko-KR/common.json @@ -340,11 +340,25 @@ "modelProvider.auth.unAuthorized": "무단", "modelProvider.buyQuota": "할당량 구매", "modelProvider.callTimes": "호출 횟수", + "modelProvider.card.aiCreditsInUse": "AI 크레딧 사용 중", + "modelProvider.card.aiCreditsOption": "AI 크레딧", + "modelProvider.card.apiKeyOption": "API Key", + "modelProvider.card.apiKeyRequired": "API 키 설정 필요", + "modelProvider.card.apiKeyUnavailableFallback": "API Key를 사용할 수 없어 AI 크레딧을 사용 중입니다", + "modelProvider.card.apiKeyUnavailableFallbackDescription": "API Key 설정을 확인하여 다시 전환하세요", "modelProvider.card.buyQuota": "Buy Quota", "modelProvider.card.callTimes": "호출 횟수", + "modelProvider.card.creditsExhaustedDescription": "플랜을 업그레이드하거나 API 키를 설정하세요", + "modelProvider.card.creditsExhaustedFallback": "AI 크레딧이 소진되어 API Key를 사용 중입니다", + "modelProvider.card.creditsExhaustedFallbackDescription": "AI 크레딧 우선 사용을 재개하려면 플랜을 업그레이드하세요.", + "modelProvider.card.creditsExhaustedMessage": "AI 크레딧이 소진되었습니다", "modelProvider.card.modelAPI": "{{modelName}} 모델이 API 키를 사용하고 있습니다.", "modelProvider.card.modelNotSupported": "{{modelName}} 모델이 설치되지 않았습니다.", "modelProvider.card.modelSupported": "{{modelName}} 모델이 이 할당량을 사용하고 있습니다.", + "modelProvider.card.noApiKeysDescription": "자체 모델 자격 증명을 사용하려면 API 키를 추가하세요.", + "modelProvider.card.noApiKeysFallback": "API Key가 없어 AI 크레딧을 사용 중입니다", + "modelProvider.card.noApiKeysTitle": "아직 API 키가 구성되지 않았습니다", + "modelProvider.card.noAvailableUsage": "사용 가능한 용량 없음", "modelProvider.card.onTrial": "트라이얼 중", "modelProvider.card.paid": "유료", "modelProvider.card.priorityUse": "우선 사용", @@ -353,6 +367,11 @@ "modelProvider.card.removeKey": "API 키 제거", "modelProvider.card.tip": "메시지 크레딧은 {{modelNames}}의 모델을 지원합니다. 유료 할당량에 우선순위가 부여됩니다. 
무료 할당량은 유료 할당량이 소진된 후 사용됩니다.", "modelProvider.card.tokens": "토큰", + "modelProvider.card.unavailable": "사용 불가", + "modelProvider.card.upgradePlan": "플랜 업그레이드", + "modelProvider.card.usageLabel": "사용량", + "modelProvider.card.usagePriority": "사용 우선순위", + "modelProvider.card.usagePriorityTip": "모델 실행 시 우선 사용할 리소스를 설정합니다.", "modelProvider.collapse": "축소", "modelProvider.config": "설정", "modelProvider.configLoadBalancing": "Config 로드 밸런싱", @@ -387,9 +406,11 @@ "modelProvider.model": "모델", "modelProvider.modelAndParameters": "모델 및 매개변수", "modelProvider.modelHasBeenDeprecated": "이 모델은 더 이상 사용되지 않습니다", + "modelProvider.modelSettings": "모델 설정", "modelProvider.models": "모델", "modelProvider.modelsNum": "{{num}}개의 모델", "modelProvider.noModelFound": "{{model}}에 대한 모델을 찾을 수 없습니다", + "modelProvider.noneConfigured": "애플리케이션을 실행하려면 기본 시스템 모델을 구성하세요", "modelProvider.notConfigured": "시스템 모델이 아직 완전히 설정되지 않아 일부 기능을 사용할 수 없습니다.", "modelProvider.parameters": "매개변수", "modelProvider.parametersInvalidRemoved": "일부 매개변수가 유효하지 않아 제거되었습니다.", @@ -403,8 +424,25 @@ "modelProvider.resetDate": "{{date}}에 재설정", "modelProvider.searchModel": "검색 모델", "modelProvider.selectModel": "모델 선택", + "modelProvider.selector.aiCredits": "AI 크레딧", + "modelProvider.selector.apiKeyUnavailable": "API Key 사용 불가", + "modelProvider.selector.apiKeyUnavailableTip": "API Key가 삭제되었습니다. 새 API Key를 설정하세요.", + "modelProvider.selector.configure": "구성", + "modelProvider.selector.configureRequired": "구성 필요", + "modelProvider.selector.creditsExhausted": "크레딧 소진", + "modelProvider.selector.creditsExhaustedTip": "AI 크레딧이 소진되었습니다. 
플랜을 업그레이드하거나 API 키를 추가하세요.", + "modelProvider.selector.disabled": "비활성화됨", + "modelProvider.selector.discoverMoreInMarketplace": "마켓플레이스에서 더 찾아보기", "modelProvider.selector.emptySetting": "설정으로 이동하여 구성하세요", "modelProvider.selector.emptyTip": "사용 가능한 모델이 없습니다", + "modelProvider.selector.fromMarketplace": "마켓플레이스에서", + "modelProvider.selector.incompatible": "호환되지 않음", + "modelProvider.selector.incompatibleTip": "이 모델은 현재 버전에서 사용할 수 없습니다. 다른 사용 가능한 모델을 선택하세요.", + "modelProvider.selector.install": "설치", + "modelProvider.selector.modelProviderSettings": "모델 공급자 설정", + "modelProvider.selector.noProviderConfigured": "모델 공급자가 구성되지 않았습니다", + "modelProvider.selector.noProviderConfiguredDesc": "마켓플레이스에서 설치하거나 설정에서 공급자를 구성하세요.", + "modelProvider.selector.onlyCompatibleModelsShown": "호환되는 모델만 표시됩니다", "modelProvider.selector.rerankTip": "재랭크 모델을 설정하세요", "modelProvider.selector.tip": "이 모델은 삭제되었습니다. 다른 모델을 추가하거나 다른 모델을 선택하세요.", "modelProvider.setupModelFirst": "먼저 모델을 설정하세요", diff --git a/web/i18n/ko-KR/plugin.json b/web/i18n/ko-KR/plugin.json index 2f20a584c5..93b28beeb4 100644 --- a/web/i18n/ko-KR/plugin.json +++ b/web/i18n/ko-KR/plugin.json @@ -3,6 +3,7 @@ "action.delete": "플러그인 제거", "action.deleteContentLeft": "제거하시겠습니까?", "action.deleteContentRight": "플러그인?", + "action.deleteSuccess": "플러그인이 성공적으로 제거되었습니다", "action.pluginInfo": "플러그인 정보", "action.usedInApps": "이 플러그인은 {{num}}개의 앱에서 사용되고 있습니다.", "allCategories": "모든 카테고리", @@ -114,6 +115,7 @@ "detailPanel.operation.install": "설치", "detailPanel.operation.remove": "제거", "detailPanel.operation.update": "업데이트", + "detailPanel.operation.updateTooltip": "최신 모델에 액세스하려면 업데이트하세요.", "detailPanel.operation.viewDetail": "자세히보기", "detailPanel.serviceOk": "서비스 정상", "detailPanel.strategyNum": "{{num}} {{strategy}} 포함", @@ -231,12 +233,18 @@ "source.local": "로컬 패키지 파일", "source.marketplace": "마켓", "task.clearAll": "모두 지우기", + "task.errorMsg.github": "이 플러그인을 자동으로 설치할 수 없었습니다.\nGitHub에서 설치하세요.", + "task.errorMsg.marketplace": "이 플러그인을 자동으로 설치할 
수 없었습니다.\n마켓플레이스에서 설치하세요.", + "task.errorMsg.unknown": "이 플러그인을 설치할 수 없었습니다.\n플러그인 소스를 확인할 수 없습니다.", "task.errorPlugins": "Failed to Install Plugins", "task.installError": "{{errorLength}} 플러그인 설치 실패, 보려면 클릭하십시오.", + "task.installFromGithub": "GitHub에서 설치", + "task.installFromMarketplace": "마켓플레이스에서 설치", "task.installSuccess": "{{successLength}} plugins installed successfully", "task.installed": "Installed", "task.installedError": "{{errorLength}} 플러그인 설치 실패", "task.installing": "플러그인 설치 중.", + "task.installingHint": "설치 중... 몇 분 정도 걸릴 수 있습니다.", "task.installingWithError": "{{installingLength}} 플러그인 설치, {{successLength}} 성공, {{errorLength}} 실패", "task.installingWithSuccess": "{{installingLength}} 플러그인 설치, {{successLength}} 성공.", "task.runningPlugins": "Installing Plugins", diff --git a/web/i18n/ko-KR/workflow.json b/web/i18n/ko-KR/workflow.json index b83afea149..24e8e634d3 100644 --- a/web/i18n/ko-KR/workflow.json +++ b/web/i18n/ko-KR/workflow.json @@ -305,6 +305,7 @@ "error.operations.updatingWorkflow": "워크플로 업데이트", "error.startNodeRequired": "{{operation}} 전에 먼저 시작 노드를 추가해 주세요", "errorMsg.authRequired": "인증이 필요합니다", + "errorMsg.configureModel": "모델을 구성하세요", "errorMsg.fieldRequired": "{{field}}가 필요합니다", "errorMsg.fields.code": "코드", "errorMsg.fields.model": "모델", @@ -314,6 +315,7 @@ "errorMsg.fields.visionVariable": "비전 변수", "errorMsg.invalidJson": "{{field}}는 잘못된 JSON 입니다", "errorMsg.invalidVariable": "잘못된 변수", + "errorMsg.modelPluginNotInstalled": "잘못된 변수입니다. 
이 변수를 사용하려면 모델을 구성하세요.", "errorMsg.noValidTool": "{{field}} 유효한 도구가 선택되지 않았습니다.", "errorMsg.rerankModelRequired": "Rerank Model 을 켜기 전에 설정에서 모델이 성공적으로 구성되었는지 확인하십시오.", "errorMsg.startNodeRequired": "{{operation}} 전에 먼저 시작 노드를 추가해 주세요", @@ -444,6 +446,7 @@ "nodes.common.memory.windowSize": "창 크기", "nodes.common.outputVars": "출력 변수", "nodes.common.pluginNotInstalled": "플러그인이 설치되지 않았습니다", + "nodes.common.pluginsNotInstalled": "{{count}}개의 플러그인이 설치되지 않았습니다", "nodes.common.retry.maxRetries": "최대 재시도 횟수", "nodes.common.retry.ms": "ms", "nodes.common.retry.retries": "{{num}} 재시도", @@ -686,9 +689,14 @@ "nodes.knowledgeBase.chunksInput": "청크", "nodes.knowledgeBase.chunksInputTip": "지식 기반 노드의 입력 변수는 Chunks입니다. 변수 유형은 선택된 청크 구조와 일치해야 하는 특정 JSON 스키마를 가진 객체입니다.", "nodes.knowledgeBase.chunksVariableIsRequired": "Chunks 변수는 필수입니다", + "nodes.knowledgeBase.embeddingModelApiKeyUnavailable": "API Key 사용 불가", + "nodes.knowledgeBase.embeddingModelCreditsExhausted": "크레딧 소진", + "nodes.knowledgeBase.embeddingModelIncompatible": "호환되지 않음", "nodes.knowledgeBase.embeddingModelIsInvalid": "임베딩 모델이 유효하지 않습니다", "nodes.knowledgeBase.embeddingModelIsRequired": "임베딩 모델이 필요합니다", + "nodes.knowledgeBase.embeddingModelNotConfigured": "임베딩 모델이 구성되지 않았습니다", "nodes.knowledgeBase.indexMethodIsRequired": "인덱스 메서드가 필요합니다.", + "nodes.knowledgeBase.notConfigured": "구성되지 않음", "nodes.knowledgeBase.rerankingModelIsInvalid": "재정렬 모델이 유효하지 않습니다", "nodes.knowledgeBase.rerankingModelIsRequired": "재순위 모델이 필요합니다", "nodes.knowledgeBase.retrievalSettingIsRequired": "검색 설정이 필요합니다.", @@ -872,6 +880,7 @@ "nodes.templateTransform.codeSupportTip": "Jinja2 만 지원합니다", "nodes.templateTransform.inputVars": "입력 변수", "nodes.templateTransform.outputVars.output": "변환된 내용", + "nodes.tool.authorizationRequired": "인증 필요", "nodes.tool.authorize": "권한 부여", "nodes.tool.inputVars": "입력 변수", "nodes.tool.insertPlaceholder1": "타이프하거나 누르세요", @@ -1062,10 +1071,12 @@ "panel.change": "변경", "panel.changeBlock": "노드 변경", "panel.checklist": "체크리스트", 
+ "panel.checklistDescription": "게시 전에 다음 문제를 해결하세요", "panel.checklistResolved": "모든 문제가 해결되었습니다", "panel.checklistTip": "게시하기 전에 모든 문제가 해결되었는지 확인하세요", "panel.createdBy": "작성자 ", "panel.goTo": "로 이동", + "panel.goToFix": "수정하러 가기", "panel.helpLink": "도움말 센터", "panel.maximize": "캔버스 전체 화면", "panel.minimize": "전체 화면 종료", diff --git a/web/i18n/nl-NL/app-debug.json b/web/i18n/nl-NL/app-debug.json index b667cfb052..7552cb7cc9 100644 --- a/web/i18n/nl-NL/app-debug.json +++ b/web/i18n/nl-NL/app-debug.json @@ -235,11 +235,16 @@ "inputs.run": "RUN", "inputs.title": "Debug & Preview", "inputs.userInputField": "User Input Field", + "manageModels": "Modellen beheren", "modelConfig.modeType.chat": "Chat", "modelConfig.modeType.completion": "Complete", "modelConfig.model": "Model", "modelConfig.setTone": "Set tone of responses", "modelConfig.title": "Model and Parameters", + "noModelProviderConfigured": "Geen modelprovider geconfigureerd", + "noModelProviderConfiguredTip": "Installeer of configureer een modelprovider om te beginnen.", + "noModelSelected": "Geen model geselecteerd", + "noModelSelectedTip": "Configureer hierboven een model om door te gaan.", "noResult": "Output will be displayed here.", "notSetAPIKey.description": "The LLM provider key has not been set, and it needs to be set before debugging.", "notSetAPIKey.settingBtn": "Go to settings", diff --git a/web/i18n/nl-NL/common.json b/web/i18n/nl-NL/common.json index 013f06b035..fb1b332a0c 100644 --- a/web/i18n/nl-NL/common.json +++ b/web/i18n/nl-NL/common.json @@ -340,11 +340,25 @@ "modelProvider.auth.unAuthorized": "Unauthorized", "modelProvider.buyQuota": "Buy Quota", "modelProvider.callTimes": "Call times", + "modelProvider.card.aiCreditsInUse": "AI-tegoeden in gebruik", + "modelProvider.card.aiCreditsOption": "AI-tegoeden", + "modelProvider.card.apiKeyOption": "API-sleutel", + "modelProvider.card.apiKeyRequired": "API-sleutel vereist", + "modelProvider.card.apiKeyUnavailableFallback": "API-sleutel niet 
beschikbaar, nu worden AI-tegoeden gebruikt", + "modelProvider.card.apiKeyUnavailableFallbackDescription": "Controleer uw API-sleutelconfiguratie om terug te schakelen", "modelProvider.card.buyQuota": "Buy Quota", "modelProvider.card.callTimes": "Call times", + "modelProvider.card.creditsExhaustedDescription": "Verhoog uw abonnement of configureer een API-sleutel", + "modelProvider.card.creditsExhaustedFallback": "AI-tegoeden uitgeput, nu wordt API-sleutel gebruikt", + "modelProvider.card.creditsExhaustedFallbackDescription": "Verhoog uw abonnement om AI-tegoed prioriteit te hervatten.", + "modelProvider.card.creditsExhaustedMessage": "AI-tegoeden zijn uitgeput", "modelProvider.card.modelAPI": "{{modelName}} models are using the API Key.", - "modelProvider.card.modelNotSupported": "{{modelName}} models are not installed.", - "modelProvider.card.modelSupported": "{{modelName}} models are using this quota.", + "modelProvider.card.modelNotSupported": "{{modelName}} niet geïnstalleerd", + "modelProvider.card.modelSupported": "{{modelName}} modellen gebruiken deze tegoeden.", + "modelProvider.card.noApiKeysDescription": "Voeg een API-sleutel toe om uw eigen modelgegevens te gebruiken.", + "modelProvider.card.noApiKeysFallback": "Geen API-sleutels, AI-tegoeden worden gebruikt", + "modelProvider.card.noApiKeysTitle": "Nog geen API-sleutels geconfigureerd", + "modelProvider.card.noAvailableUsage": "Geen beschikbaar gebruik", "modelProvider.card.onTrial": "On Trial", "modelProvider.card.paid": "Paid", "modelProvider.card.priorityUse": "Priority use", @@ -353,6 +367,11 @@ "modelProvider.card.removeKey": "Remove API Key", "modelProvider.card.tip": "Message Credits supports models from {{modelNames}}. Priority will be given to the paid quota. 
The free quota will be used after the paid quota is exhausted.", "modelProvider.card.tokens": "Tokens", + "modelProvider.card.unavailable": "Niet beschikbaar", + "modelProvider.card.upgradePlan": "verhoog uw abonnement", + "modelProvider.card.usageLabel": "Gebruik", + "modelProvider.card.usagePriority": "Gebruiksprioriteit", + "modelProvider.card.usagePriorityTip": "Stel in welke resource als eerste wordt gebruikt bij het uitvoeren van modellen.", "modelProvider.collapse": "Collapse", "modelProvider.config": "Config", "modelProvider.configLoadBalancing": "Config Load Balancing", @@ -387,9 +406,11 @@ "modelProvider.model": "Model", "modelProvider.modelAndParameters": "Model and Parameters", "modelProvider.modelHasBeenDeprecated": "This model has been deprecated", + "modelProvider.modelSettings": "Modelinstellingen", "modelProvider.models": "Models", "modelProvider.modelsNum": "{{num}} Models", "modelProvider.noModelFound": "No model found for {{model}}", + "modelProvider.noneConfigured": "Configureer een standaard systeemmodel om applicaties uit te voeren", "modelProvider.notConfigured": "The system model has not yet been fully configured", "modelProvider.parameters": "PARAMETERS", "modelProvider.parametersInvalidRemoved": "Some parameters are invalid and have been removed", @@ -403,8 +424,25 @@ "modelProvider.resetDate": "Reset on {{date}}", "modelProvider.searchModel": "Search model", "modelProvider.selectModel": "Select your model", + "modelProvider.selector.aiCredits": "AI-tegoeden", + "modelProvider.selector.apiKeyUnavailable": "API-sleutel niet beschikbaar", + "modelProvider.selector.apiKeyUnavailableTip": "De API-sleutel is verwijderd. Configureer een nieuwe API-sleutel.", + "modelProvider.selector.configure": "Configureren", + "modelProvider.selector.configureRequired": "Configuratie vereist", + "modelProvider.selector.creditsExhausted": "Tegoeden uitgeput", + "modelProvider.selector.creditsExhaustedTip": "Uw AI-tegoeden zijn uitgeput. 
Verhoog uw abonnement of voeg een API-sleutel toe.", + "modelProvider.selector.disabled": "Uitgeschakeld", + "modelProvider.selector.discoverMoreInMarketplace": "Ontdek meer in de Marketplace", "modelProvider.selector.emptySetting": "Please go to settings to configure", "modelProvider.selector.emptyTip": "No available models", + "modelProvider.selector.fromMarketplace": "Vanuit de Marketplace", + "modelProvider.selector.incompatible": "Incompatibel", + "modelProvider.selector.incompatibleTip": "Dit model is niet beschikbaar in de huidige versie. Selecteer een ander beschikbaar model.", + "modelProvider.selector.install": "Installeren", + "modelProvider.selector.modelProviderSettings": "Modelproviderinstellingen", + "modelProvider.selector.noProviderConfigured": "Geen modelprovider geconfigureerd", + "modelProvider.selector.noProviderConfiguredDesc": "Blader in de Marketplace om er een te installeren, of configureer providers in de instellingen.", + "modelProvider.selector.onlyCompatibleModelsShown": "Alleen compatibele modellen worden weergegeven", "modelProvider.selector.rerankTip": "Please set up the Rerank model", "modelProvider.selector.tip": "This model has been removed. 
Please add a model or select another model.", "modelProvider.setupModelFirst": "Please set up your model first", diff --git a/web/i18n/nl-NL/plugin.json b/web/i18n/nl-NL/plugin.json index c7f091a442..4e0301b5f5 100644 --- a/web/i18n/nl-NL/plugin.json +++ b/web/i18n/nl-NL/plugin.json @@ -3,6 +3,7 @@ "action.delete": "Remove plugin", "action.deleteContentLeft": "Would you like to remove ", "action.deleteContentRight": " plugin?", + "action.deleteSuccess": "Plugin succesvol verwijderd", "action.pluginInfo": "Plugin info", "action.usedInApps": "This plugin is being used in {{num}} apps.", "allCategories": "All Categories", @@ -114,6 +115,7 @@ "detailPanel.operation.install": "Install", "detailPanel.operation.remove": "Remove", "detailPanel.operation.update": "Update", + "detailPanel.operation.updateTooltip": "Werk bij voor toegang tot de nieuwste modellen.", "detailPanel.operation.viewDetail": "View Detail", "detailPanel.serviceOk": "Service OK", "detailPanel.strategyNum": "{{num}} {{strategy}} INCLUDED", @@ -231,12 +233,18 @@ "source.local": "Local Package File", "source.marketplace": "Marketplace", "task.clearAll": "Clear all", + "task.errorMsg.github": "Deze plugin is niet automatisch geïnstalleerd.\nInstalleer het vanuit GitHub.", + "task.errorMsg.marketplace": "Deze plugin is niet automatisch geïnstalleerd.\nInstalleer het vanuit de Marketplace.", + "task.errorMsg.unknown": "Deze plugin is niet geïnstalleerd.\nDe pluginbron kon niet worden geïdentificeerd.", "task.errorPlugins": "Failed to Install Plugins", "task.installError": "{{errorLength}} plugins failed to install, click to view", + "task.installFromGithub": "Installeren vanuit GitHub", + "task.installFromMarketplace": "Installeren vanuit de Marketplace", "task.installSuccess": "{{successLength}} plugins installed successfully", "task.installed": "Installed", "task.installedError": "{{errorLength}} plugins failed to install", "task.installing": "Installing plugins", + "task.installingHint": "Installeren... 
Dit kan enkele minuten duren.", "task.installingWithError": "Installing {{installingLength}} plugins, {{successLength}} success, {{errorLength}} failed", "task.installingWithSuccess": "Installing {{installingLength}} plugins, {{successLength}} success.", "task.runningPlugins": "Installing Plugins", diff --git a/web/i18n/nl-NL/workflow.json b/web/i18n/nl-NL/workflow.json index 35555d0e7b..891df72387 100644 --- a/web/i18n/nl-NL/workflow.json +++ b/web/i18n/nl-NL/workflow.json @@ -305,6 +305,7 @@ "error.operations.updatingWorkflow": "updating workflow", "error.startNodeRequired": "Please add a start node first before {{operation}}", "errorMsg.authRequired": "Authorization is required", + "errorMsg.configureModel": "Configureer een model", "errorMsg.fieldRequired": "{{field}} is required", "errorMsg.fields.code": "Code", "errorMsg.fields.model": "Model", @@ -313,7 +314,8 @@ "errorMsg.fields.variableValue": "Variable Value", "errorMsg.fields.visionVariable": "Vision Variable", "errorMsg.invalidJson": "{{field}} is invalid JSON", - "errorMsg.invalidVariable": "Invalid variable", + "errorMsg.invalidVariable": "Ongeldige variabele. Selecteer een bestaande variabele.", + "errorMsg.modelPluginNotInstalled": "Ongeldige variabele. 
Configureer een model om deze variabele in te schakelen.", "errorMsg.noValidTool": "{{field}} no valid tool selected", "errorMsg.rerankModelRequired": "A configured Rerank Model is required", "errorMsg.startNodeRequired": "Please add a start node first before {{operation}}", @@ -444,6 +446,7 @@ "nodes.common.memory.windowSize": "Window Size", "nodes.common.outputVars": "Output Variables", "nodes.common.pluginNotInstalled": "Plugin is not installed", + "nodes.common.pluginsNotInstalled": "{{count}} plugins niet geïnstalleerd", "nodes.common.retry.maxRetries": "max retries", "nodes.common.retry.ms": "ms", "nodes.common.retry.retries": "{{num}} Retries", @@ -686,9 +689,14 @@ "nodes.knowledgeBase.chunksInput": "Chunks", "nodes.knowledgeBase.chunksInputTip": "The input variable of the knowledge base node is Chunks. The variable type is an object with a specific JSON Schema which must be consistent with the selected chunk structure.", "nodes.knowledgeBase.chunksVariableIsRequired": "Chunks variable is required", + "nodes.knowledgeBase.embeddingModelApiKeyUnavailable": "API-sleutel niet beschikbaar", + "nodes.knowledgeBase.embeddingModelCreditsExhausted": "Tegoeden uitgeput", + "nodes.knowledgeBase.embeddingModelIncompatible": "Incompatibel", "nodes.knowledgeBase.embeddingModelIsInvalid": "Embedding model is invalid", "nodes.knowledgeBase.embeddingModelIsRequired": "Embedding model is required", + "nodes.knowledgeBase.embeddingModelNotConfigured": "Inbeddingsmodel niet geconfigureerd", "nodes.knowledgeBase.indexMethodIsRequired": "Index method is required", + "nodes.knowledgeBase.notConfigured": "Niet geconfigureerd", "nodes.knowledgeBase.rerankingModelIsInvalid": "Reranking model is invalid", "nodes.knowledgeBase.rerankingModelIsRequired": "Reranking model is required", "nodes.knowledgeBase.retrievalSettingIsRequired": "Retrieval setting is required", @@ -872,6 +880,7 @@ "nodes.templateTransform.codeSupportTip": "Only supports Jinja2", 
"nodes.templateTransform.inputVars": "Input Variables", "nodes.templateTransform.outputVars.output": "Transformed content", + "nodes.tool.authorizationRequired": "Autorisatie vereist", "nodes.tool.authorize": "Authorize", "nodes.tool.inputVars": "Input Variables", "nodes.tool.insertPlaceholder1": "Type or press", @@ -1062,10 +1071,12 @@ "panel.change": "Change", "panel.changeBlock": "Change Node", "panel.checklist": "Checklist", + "panel.checklistDescription": "Los de volgende problemen op vóór het publiceren", "panel.checklistResolved": "All issues are resolved", "panel.checklistTip": "Make sure all issues are resolved before publishing", "panel.createdBy": "Created By ", "panel.goTo": "Go to", + "panel.goToFix": "Ga naar repareren", "panel.helpLink": "View Docs", "panel.maximize": "Maximize Canvas", "panel.minimize": "Exit Full Screen", diff --git a/web/i18n/pl-PL/app-debug.json b/web/i18n/pl-PL/app-debug.json index 4d7c2bf2e5..d708c997e9 100644 --- a/web/i18n/pl-PL/app-debug.json +++ b/web/i18n/pl-PL/app-debug.json @@ -235,11 +235,16 @@ "inputs.run": "URUCHOM", "inputs.title": "Debugowanie i podgląd", "inputs.userInputField": "Pole wejściowe użytkownika", + "manageModels": "Zarządzaj modelami", "modelConfig.modeType.chat": "Czat", "modelConfig.modeType.completion": "Uzupełnienie", "modelConfig.model": "Model", "modelConfig.setTone": "Ustaw ton odpowiedzi", "modelConfig.title": "Model i parametry", + "noModelProviderConfigured": "Nie skonfigurowano dostawcy modeli", + "noModelProviderConfiguredTip": "Zainstaluj lub skonfiguruj dostawcę modeli, aby rozpocząć.", + "noModelSelected": "Nie wybrano modelu", + "noModelSelectedTip": "Skonfiguruj model powyżej, aby kontynuować.", "noResult": "W tym miejscu zostaną wyświetlone dane wyjściowe.", "notSetAPIKey.description": "Klucz dostawcy LLM nie został ustawiony, musi zostać ustawiony przed debugowaniem.", "notSetAPIKey.settingBtn": "Przejdź do ustawień", diff --git a/web/i18n/pl-PL/common.json 
b/web/i18n/pl-PL/common.json index 029e9dd660..130950a57c 100644 --- a/web/i18n/pl-PL/common.json +++ b/web/i18n/pl-PL/common.json @@ -340,11 +340,25 @@ "modelProvider.auth.unAuthorized": "Nieautoryzowany", "modelProvider.buyQuota": "Kup limit", "modelProvider.callTimes": "Czasy wywołań", + "modelProvider.card.aiCreditsInUse": "Używane AI credits", + "modelProvider.card.aiCreditsOption": "AI credits", + "modelProvider.card.apiKeyOption": "API Key", + "modelProvider.card.apiKeyRequired": "Wymagany klucz API", + "modelProvider.card.apiKeyUnavailableFallback": "API Key niedostępny, używane są AI credits", + "modelProvider.card.apiKeyUnavailableFallbackDescription": "Sprawdź konfigurację klucza API, aby przełączyć z powrotem", "modelProvider.card.buyQuota": "Kup limit", "modelProvider.card.callTimes": "Czasy wywołań", + "modelProvider.card.creditsExhaustedDescription": "Proszę ulepszyć swój plan lub skonfigurować klucz API", + "modelProvider.card.creditsExhaustedFallback": "AI credits wyczerpane, używany jest API Key", + "modelProvider.card.creditsExhaustedFallbackDescription": "Ulepsz swój plan, aby przywrócić priorytet AI credits.", + "modelProvider.card.creditsExhaustedMessage": "AI credits zostały wyczerpane", "modelProvider.card.modelAPI": "Modele {{modelName}} używają klucza API.", "modelProvider.card.modelNotSupported": "Modele {{modelName}} nie są zainstalowane.", "modelProvider.card.modelSupported": "Modele {{modelName}} używają tego limitu.", + "modelProvider.card.noApiKeysDescription": "Dodaj klucz API, aby zacząć korzystać z własnych poświadczeń modelu.", + "modelProvider.card.noApiKeysFallback": "Brak kluczy API, używane są AI credits", + "modelProvider.card.noApiKeysTitle": "Nie skonfigurowano jeszcze kluczy API", + "modelProvider.card.noAvailableUsage": "Brak dostępnego użycia", "modelProvider.card.onTrial": "Na próbę", "modelProvider.card.paid": "Płatny", "modelProvider.card.priorityUse": "Używanie z priorytetem", @@ -353,6 +367,11 @@ 
"modelProvider.card.removeKey": "Usuń klucz API", "modelProvider.card.tip": "Kredyty wiadomości obsługują modele od {{modelNames}}. Priorytet zostanie nadany płatnemu limitowi. Darmowy limit zostanie użyty po wyczerpaniu płatnego limitu.", "modelProvider.card.tokens": "Tokeny", + "modelProvider.card.unavailable": "Niedostępne", + "modelProvider.card.upgradePlan": "ulepsz swój plan", + "modelProvider.card.usageLabel": "Użycie", + "modelProvider.card.usagePriority": "Priorytet użycia", + "modelProvider.card.usagePriorityTip": "Ustaw, który zasób ma być używany jako pierwszy przy uruchamianiu modeli.", "modelProvider.collapse": "Zwiń", "modelProvider.config": "Konfiguracja", "modelProvider.configLoadBalancing": "Równoważenie obciążenia konfiguracji", @@ -387,9 +406,11 @@ "modelProvider.model": "Model", "modelProvider.modelAndParameters": "Model i parametry", "modelProvider.modelHasBeenDeprecated": "Ten model jest przestarzały", + "modelProvider.modelSettings": "Ustawienia modelu", "modelProvider.models": "Modele", "modelProvider.modelsNum": "{{num}} Modele", "modelProvider.noModelFound": "Nie znaleziono modelu dla {{model}}", + "modelProvider.noneConfigured": "Skonfiguruj domyślny model systemowy, aby uruchamiać aplikacje", "modelProvider.notConfigured": "Systemowy model nie został jeszcze w pełni skonfigurowany, co może skutkować niedostępnością niektórych funkcji.", "modelProvider.parameters": "PARAMETRY", "modelProvider.parametersInvalidRemoved": "Niektóre parametry są nieprawidłowe i zostały usunięte.", @@ -403,8 +424,25 @@ "modelProvider.resetDate": "Reset {{date}}", "modelProvider.searchModel": "Model wyszukiwania", "modelProvider.selectModel": "Wybierz swój model", + "modelProvider.selector.aiCredits": "AI credits", + "modelProvider.selector.apiKeyUnavailable": "API Key niedostępny", + "modelProvider.selector.apiKeyUnavailableTip": "Klucz API został usunięty. 
Proszę skonfigurować nowy klucz API.", + "modelProvider.selector.configure": "Konfiguruj", + "modelProvider.selector.configureRequired": "Wymagana konfiguracja", + "modelProvider.selector.creditsExhausted": "Kredyty wyczerpane", + "modelProvider.selector.creditsExhaustedTip": "Twoje AI credits zostały wyczerpane. Proszę ulepszyć swój plan lub dodać klucz API.", + "modelProvider.selector.disabled": "Wyłączony", + "modelProvider.selector.discoverMoreInMarketplace": "Odkryj więcej w Marketplace", "modelProvider.selector.emptySetting": "Przejdź do ustawień, aby skonfigurować", "modelProvider.selector.emptyTip": "Brak dostępnych modeli", + "modelProvider.selector.fromMarketplace": "Z Marketplace", + "modelProvider.selector.incompatible": "Niekompatybilny", + "modelProvider.selector.incompatibleTip": "Ten model nie jest dostępny w bieżącej wersji. Proszę wybrać inny dostępny model.", + "modelProvider.selector.install": "Zainstaluj", + "modelProvider.selector.modelProviderSettings": "Ustawienia dostawcy modeli", + "modelProvider.selector.noProviderConfigured": "Nie skonfigurowano dostawcy modeli", + "modelProvider.selector.noProviderConfiguredDesc": "Przeglądaj Marketplace, aby zainstalować dostawcę, lub skonfiguruj dostawców w ustawieniach.", + "modelProvider.selector.onlyCompatibleModelsShown": "Wyświetlane są tylko kompatybilne modele", "modelProvider.selector.rerankTip": "Proszę skonfigurować model ponownego rankingu", "modelProvider.selector.tip": "Ten model został usunięty. 
Proszę dodać model lub wybrać inny model.", "modelProvider.setupModelFirst": "Proszę najpierw skonfigurować swój model", diff --git a/web/i18n/pl-PL/plugin.json b/web/i18n/pl-PL/plugin.json index abe94610dc..88ac549968 100644 --- a/web/i18n/pl-PL/plugin.json +++ b/web/i18n/pl-PL/plugin.json @@ -3,6 +3,7 @@ "action.delete": "Usuń wtyczkę", "action.deleteContentLeft": "Czy chcesz usunąć", "action.deleteContentRight": "wtyczka?", + "action.deleteSuccess": "Wtyczka została pomyślnie usunięta", "action.pluginInfo": "Informacje o wtyczce", "action.usedInApps": "Ta wtyczka jest używana w aplikacjach {{num}}.", "allCategories": "Wszystkie kategorie", @@ -114,6 +115,7 @@ "detailPanel.operation.install": "Instalować", "detailPanel.operation.remove": "Usunąć", "detailPanel.operation.update": "Aktualizacja", + "detailPanel.operation.updateTooltip": "Zaktualizuj, aby uzyskać dostęp do najnowszych modeli.", "detailPanel.operation.viewDetail": "Pokaż szczegóły", "detailPanel.serviceOk": "Serwis OK", "detailPanel.strategyNum": "{{num}} {{strategy}} ZAWARTE", @@ -231,12 +233,18 @@ "source.local": "Lokalny plik pakietu", "source.marketplace": "Rynek", "task.clearAll": "Wyczyść wszystko", + "task.errorMsg.github": "Nie udało się automatycznie zainstalować tej wtyczki.\nProszę zainstalować ją z GitHub.", + "task.errorMsg.marketplace": "Nie udało się automatycznie zainstalować tej wtyczki.\nProszę zainstalować ją z Marketplace.", + "task.errorMsg.unknown": "Nie udało się zainstalować tej wtyczki.\nNie można zidentyfikować źródła wtyczki.", "task.errorPlugins": "Failed to Install Plugins", "task.installError": "Nie udało się zainstalować wtyczek {{errorLength}}, kliknij, aby wyświetlić", + "task.installFromGithub": "Zainstaluj z GitHub", + "task.installFromMarketplace": "Zainstaluj z Marketplace", "task.installSuccess": "{{successLength}} plugins installed successfully", "task.installed": "Installed", "task.installedError": "Nie udało się zainstalować wtyczek {{errorLength}}", 
"task.installing": "Instalowanie wtyczek.", + "task.installingHint": "Instalowanie... To może potrwać kilka minut.", "task.installingWithError": "Instalacja wtyczek {{installingLength}}, {{successLength}} powodzenie, {{errorLength}} niepowodzenie", "task.installingWithSuccess": "Instalacja wtyczek {{installingLength}}, {{successLength}} powodzenie.", "task.runningPlugins": "Installing Plugins", diff --git a/web/i18n/pl-PL/workflow.json b/web/i18n/pl-PL/workflow.json index 52514fe6b7..8aad8e0b71 100644 --- a/web/i18n/pl-PL/workflow.json +++ b/web/i18n/pl-PL/workflow.json @@ -305,6 +305,7 @@ "error.operations.updatingWorkflow": "aktualizowanie przepływu pracy", "error.startNodeRequired": "Najpierw dodaj węzeł początkowy przed {{operation}}", "errorMsg.authRequired": "Wymagana autoryzacja", + "errorMsg.configureModel": "Skonfiguruj model", "errorMsg.fieldRequired": "{{field}} jest wymagane", "errorMsg.fields.code": "Kod", "errorMsg.fields.model": "Model", @@ -314,6 +315,7 @@ "errorMsg.fields.visionVariable": "Zmienna wizji", "errorMsg.invalidJson": "{{field}} jest nieprawidłowym JSON-em", "errorMsg.invalidVariable": "Nieprawidłowa zmienna", + "errorMsg.modelPluginNotInstalled": "Nieprawidłowa zmienna. 
Skonfiguruj model, aby włączyć tę zmienną.", "errorMsg.noValidTool": "{{field}} nie wybrano prawidłowego narzędzia", "errorMsg.rerankModelRequired": "Przed włączeniem Rerank Model upewnij się, że model został pomyślnie skonfigurowany w ustawieniach.", "errorMsg.startNodeRequired": "Najpierw dodaj węzeł początkowy przed {{operation}}", @@ -444,6 +446,7 @@ "nodes.common.memory.windowSize": "Rozmiar okna", "nodes.common.outputVars": "Zmienne wyjściowe", "nodes.common.pluginNotInstalled": "Wtyczka nie jest zainstalowana", + "nodes.common.pluginsNotInstalled": "{{count}} wtyczek niezainstalowanych", "nodes.common.retry.maxRetries": "Maksymalna liczba ponownych prób", "nodes.common.retry.ms": "Ms", "nodes.common.retry.retries": "{{num}} Ponownych prób", @@ -686,9 +689,14 @@ "nodes.knowledgeBase.chunksInput": "Kawałki", "nodes.knowledgeBase.chunksInputTip": "Zmienna wejściowa węzła bazy wiedzy to Chunks. Typ zmiennej to obiekt z określonym schematem JSON, który musi być zgodny z wybraną strukturą chunk.", "nodes.knowledgeBase.chunksVariableIsRequired": "Wymagana jest zmienna Chunks", + "nodes.knowledgeBase.embeddingModelApiKeyUnavailable": "Klucz API niedostępny", + "nodes.knowledgeBase.embeddingModelCreditsExhausted": "Kredyty wyczerpane", + "nodes.knowledgeBase.embeddingModelIncompatible": "Niekompatybilny", "nodes.knowledgeBase.embeddingModelIsInvalid": "Model osadzania jest nieprawidłowy", "nodes.knowledgeBase.embeddingModelIsRequired": "Wymagany jest model osadzania", + "nodes.knowledgeBase.embeddingModelNotConfigured": "Model embeddingu nie jest skonfigurowany", "nodes.knowledgeBase.indexMethodIsRequired": "Metoda indeksowa jest wymagana", + "nodes.knowledgeBase.notConfigured": "Nieskonfigurowany", "nodes.knowledgeBase.rerankingModelIsInvalid": "Model ponownego rankingowania jest nieprawidłowy", "nodes.knowledgeBase.rerankingModelIsRequired": "Wymagany jest model ponownego rankingu", "nodes.knowledgeBase.retrievalSettingIsRequired": "Wymagane jest ustawienie 
pobierania", @@ -872,6 +880,7 @@ "nodes.templateTransform.codeSupportTip": "Obsługuje tylko Jinja2", "nodes.templateTransform.inputVars": "Zmienne wejściowe", "nodes.templateTransform.outputVars.output": "Przekształcona treść", + "nodes.tool.authorizationRequired": "Wymagana autoryzacja", "nodes.tool.authorize": "Autoryzuj", "nodes.tool.inputVars": "Zmienne wejściowe", "nodes.tool.insertPlaceholder1": "Wpisz lub naciśnij", @@ -1062,10 +1071,12 @@ "panel.change": "Zmień", "panel.changeBlock": "Zmień węzeł", "panel.checklist": "Lista kontrolna", + "panel.checklistDescription": "Rozwiąż poniższe problemy przed opublikowaniem", "panel.checklistResolved": "Wszystkie problemy zostały rozwiązane", "panel.checklistTip": "Upewnij się, że wszystkie problemy zostały rozwiązane przed opublikowaniem", "panel.createdBy": "Stworzone przez ", "panel.goTo": "Idź do", + "panel.goToFix": "Przejdź do naprawy", "panel.helpLink": "Pomoc", "panel.maximize": "Maksymalizuj płótno", "panel.minimize": "Wyjdź z trybu pełnoekranowego", diff --git a/web/i18n/pt-BR/app-debug.json b/web/i18n/pt-BR/app-debug.json index 5b0b4c9969..3b1a0da424 100644 --- a/web/i18n/pt-BR/app-debug.json +++ b/web/i18n/pt-BR/app-debug.json @@ -235,11 +235,16 @@ "inputs.run": "EXECUTAR", "inputs.title": "Depuração e Visualização", "inputs.userInputField": "Campo de Entrada do Usuário", + "manageModels": "Gerenciar modelos", "modelConfig.modeType.chat": "Chat", "modelConfig.modeType.completion": "Completar", "modelConfig.model": "Modelo", "modelConfig.setTone": "Definir tom das respostas", "modelConfig.title": "Modelo e Parâmetros", + "noModelProviderConfigured": "Nenhum provedor de modelo configurado", + "noModelProviderConfiguredTip": "Instale ou configure um provedor de modelo para começar.", + "noModelSelected": "Nenhum modelo selecionado", + "noModelSelectedTip": "Configure um modelo acima para continuar.", "noResult": "A saída será exibida aqui.", "notSetAPIKey.description": "A chave do provedor LLM não foi 
definida e precisa ser definida antes da depuração.", "notSetAPIKey.settingBtn": "Ir para configurações", diff --git a/web/i18n/pt-BR/common.json b/web/i18n/pt-BR/common.json index 0f2106b1eb..6840bb964b 100644 --- a/web/i18n/pt-BR/common.json +++ b/web/i18n/pt-BR/common.json @@ -340,11 +340,25 @@ "modelProvider.auth.unAuthorized": "Não autorizado", "modelProvider.buyQuota": "Comprar Quota", "modelProvider.callTimes": "Chamadas", + "modelProvider.card.aiCreditsInUse": "AI credits em uso", + "modelProvider.card.aiCreditsOption": "AI credits", + "modelProvider.card.apiKeyOption": "API Key", + "modelProvider.card.apiKeyRequired": "API Key necessária", + "modelProvider.card.apiKeyUnavailableFallback": "API Key indisponível, usando AI credits", + "modelProvider.card.apiKeyUnavailableFallbackDescription": "Verifique a configuração da sua API Key para voltar a usá-la", "modelProvider.card.buyQuota": "Comprar Quota", "modelProvider.card.callTimes": "Chamadas", + "modelProvider.card.creditsExhaustedDescription": "Por favor, atualize seu plano ou configure uma API Key", + "modelProvider.card.creditsExhaustedFallback": "AI credits esgotados, usando API Key", + "modelProvider.card.creditsExhaustedFallbackDescription": "Atualize seu plano para retomar a prioridade de AI credits.", + "modelProvider.card.creditsExhaustedMessage": "AI credits foram esgotados", "modelProvider.card.modelAPI": "Os modelos {{modelName}} estão usando a Chave API.", "modelProvider.card.modelNotSupported": "Os modelos {{modelName}} não estão instalados.", "modelProvider.card.modelSupported": "Os modelos {{modelName}} estão usando esta cota.", + "modelProvider.card.noApiKeysDescription": "Adicione uma API Key para usar suas próprias credenciais de modelo.", + "modelProvider.card.noApiKeysFallback": "Sem API Keys, usando AI credits", + "modelProvider.card.noApiKeysTitle": "Nenhuma API Key configurada", + "modelProvider.card.noAvailableUsage": "Sem uso disponível", "modelProvider.card.onTrial": "Em Teste", 
"modelProvider.card.paid": "Pago", "modelProvider.card.priorityUse": "Uso prioritário", @@ -353,6 +367,11 @@ "modelProvider.card.removeKey": "Remover Chave da API", "modelProvider.card.tip": "Créditos de mensagens suportam modelos de {{modelNames}}. A prioridade será dada à quota paga. A quota gratuita será usada após a quota paga ser esgotada.", "modelProvider.card.tokens": "Tokens", + "modelProvider.card.unavailable": "Indisponível", + "modelProvider.card.upgradePlan": "atualize seu plano", + "modelProvider.card.usageLabel": "Uso", + "modelProvider.card.usagePriority": "Prioridade de uso", + "modelProvider.card.usagePriorityTip": "Defina qual recurso usar primeiro ao executar modelos.", "modelProvider.collapse": "Recolher", "modelProvider.config": "Configuração", "modelProvider.configLoadBalancing": "Balanceamento de carga de configuração", @@ -387,9 +406,11 @@ "modelProvider.model": "Modelo", "modelProvider.modelAndParameters": "Modelo e Parâmetros", "modelProvider.modelHasBeenDeprecated": "Este modelo foi preterido", + "modelProvider.modelSettings": "Configurações de modelo", "modelProvider.models": "Modelos", "modelProvider.modelsNum": "{{num}} Modelos", "modelProvider.noModelFound": "Nenhum modelo encontrado para {{model}}", + "modelProvider.noneConfigured": "Configure um modelo de sistema padrão para executar aplicações", "modelProvider.notConfigured": "O modelo do sistema ainda não foi totalmente configurado e algumas funções podem estar indisponíveis.", "modelProvider.parameters": "PARÂMETROS", "modelProvider.parametersInvalidRemoved": "Alguns parâmetros são inválidos e foram removidos", @@ -403,8 +424,25 @@ "modelProvider.resetDate": "Redefinir em {{date}}", "modelProvider.searchModel": "Modelo de pesquisa", "modelProvider.selectModel": "Selecione seu modelo", + "modelProvider.selector.aiCredits": "AI credits", + "modelProvider.selector.apiKeyUnavailable": "API Key indisponível", + "modelProvider.selector.apiKeyUnavailableTip": "A API Key foi removida. 
Por favor, configure uma nova API Key.", + "modelProvider.selector.configure": "Configurar", + "modelProvider.selector.configureRequired": "Configuração necessária", + "modelProvider.selector.creditsExhausted": "Créditos esgotados", + "modelProvider.selector.creditsExhaustedTip": "Seus AI credits foram esgotados. Por favor, atualize seu plano ou adicione uma API Key.", + "modelProvider.selector.disabled": "Desativado", + "modelProvider.selector.discoverMoreInMarketplace": "Descubra mais no Marketplace", "modelProvider.selector.emptySetting": "Por favor, vá para configurações para configurar", "modelProvider.selector.emptyTip": "Nenhum modelo disponível", + "modelProvider.selector.fromMarketplace": "Do Marketplace", + "modelProvider.selector.incompatible": "Incompatível", + "modelProvider.selector.incompatibleTip": "Este modelo não está disponível na versão atual. Por favor, selecione outro modelo disponível.", + "modelProvider.selector.install": "Instalar", + "modelProvider.selector.modelProviderSettings": "Configurações do provedor de modelo", + "modelProvider.selector.noProviderConfigured": "Nenhum provedor de modelo configurado", + "modelProvider.selector.noProviderConfiguredDesc": "Navegue pelo Marketplace para instalar um, ou configure provedores nas configurações.", + "modelProvider.selector.onlyCompatibleModelsShown": "Apenas modelos compatíveis são exibidos", "modelProvider.selector.rerankTip": "Por favor, configure o modelo de reordenação", "modelProvider.selector.tip": "Este modelo foi removido. 
Adicione um modelo ou selecione outro modelo.", "modelProvider.setupModelFirst": "Por favor, configure seu modelo primeiro", diff --git a/web/i18n/pt-BR/plugin.json b/web/i18n/pt-BR/plugin.json index c9297b5840..d5d218c8b6 100644 --- a/web/i18n/pt-BR/plugin.json +++ b/web/i18n/pt-BR/plugin.json @@ -3,6 +3,7 @@ "action.delete": "Remover plugin", "action.deleteContentLeft": "Gostaria de remover", "action.deleteContentRight": "plugin?", + "action.deleteSuccess": "Plugin removido com sucesso", "action.pluginInfo": "Informações do plugin", "action.usedInApps": "Este plugin está sendo usado em aplicativos {{num}}.", "allCategories": "Todas as categorias", @@ -114,6 +115,7 @@ "detailPanel.operation.install": "Instalar", "detailPanel.operation.remove": "Retirar", "detailPanel.operation.update": "Atualização", + "detailPanel.operation.updateTooltip": "Atualize para acessar os modelos mais recentes.", "detailPanel.operation.viewDetail": "Ver detalhes", "detailPanel.serviceOk": "Serviço OK", "detailPanel.strategyNum": "{{num}} {{strategy}} INCLUSO", @@ -231,12 +233,18 @@ "source.local": "Arquivo de pacote local", "source.marketplace": "Mercado", "task.clearAll": "Apagar tudo", + "task.errorMsg.github": "Este plugin não pôde ser instalado automaticamente.\nPor favor, instale-o pelo GitHub.", + "task.errorMsg.marketplace": "Este plugin não pôde ser instalado automaticamente.\nPor favor, instale-o pelo Marketplace.", + "task.errorMsg.unknown": "Este plugin não pôde ser instalado.\nNão foi possível identificar a origem do plugin.", "task.errorPlugins": "Failed to Install Plugins", "task.installError": "{{errorLength}} plugins falha ao instalar, clique para ver", + "task.installFromGithub": "Instalar pelo GitHub", + "task.installFromMarketplace": "Instalar pelo Marketplace", "task.installSuccess": "{{successLength}} plugins installed successfully", "task.installed": "Installed", "task.installedError": "Falha na instalação dos plug-ins {{errorLength}}", "task.installing": 
"Instalando plugins.", + "task.installingHint": "Instalando... Isso pode levar alguns minutos.", "task.installingWithError": "Instalando plug-ins {{installingLength}}, {{successLength}} sucesso, {{errorLength}} falhou", "task.installingWithSuccess": "Instalando plugins {{installingLength}}, {{successLength}} sucesso.", "task.runningPlugins": "Installing Plugins", diff --git a/web/i18n/pt-BR/workflow.json b/web/i18n/pt-BR/workflow.json index 3edab361bb..dc986df82c 100644 --- a/web/i18n/pt-BR/workflow.json +++ b/web/i18n/pt-BR/workflow.json @@ -305,6 +305,7 @@ "error.operations.updatingWorkflow": "atualizando fluxo de trabalho", "error.startNodeRequired": "Por favor, adicione um nó inicial antes de {{operation}}", "errorMsg.authRequired": "Autorização é necessária", + "errorMsg.configureModel": "Configure um modelo", "errorMsg.fieldRequired": "{{field}} é obrigatório", "errorMsg.fields.code": "Código", "errorMsg.fields.model": "Modelo", @@ -314,6 +315,7 @@ "errorMsg.fields.visionVariable": "Variável de visão", "errorMsg.invalidJson": "{{field}} é um JSON inválido", "errorMsg.invalidVariable": "Variável inválida", + "errorMsg.modelPluginNotInstalled": "Variável inválida. 
Configure um modelo para habilitar esta variável.", "errorMsg.noValidTool": "{{field}} nenhuma ferramenta válida selecionada", "errorMsg.rerankModelRequired": "Antes de ativar o modelo de reclassificação, confirme se o modelo foi configurado com sucesso nas configurações.", "errorMsg.startNodeRequired": "Por favor, adicione um nó inicial antes de {{operation}}", @@ -444,6 +446,7 @@ "nodes.common.memory.windowSize": "Tamanho da janela", "nodes.common.outputVars": "Variáveis de saída", "nodes.common.pluginNotInstalled": "O plugin não está instalado", + "nodes.common.pluginsNotInstalled": "{{count}} plugins não instalados", "nodes.common.retry.maxRetries": "Máximo de tentativas", "nodes.common.retry.ms": "ms", "nodes.common.retry.retries": "{{num}} Tentativas", @@ -686,9 +689,14 @@ "nodes.knowledgeBase.chunksInput": "Pedaços", "nodes.knowledgeBase.chunksInputTip": "A variável de entrada do nó da base de conhecimento é Chunks. O tipo da variável é um objeto com um esquema JSON específico que deve ser consistente com a estrutura de chunk selecionada.", "nodes.knowledgeBase.chunksVariableIsRequired": "A variável 'chunks' é obrigatória", + "nodes.knowledgeBase.embeddingModelApiKeyUnavailable": "API Key indisponível", + "nodes.knowledgeBase.embeddingModelCreditsExhausted": "Créditos esgotados", + "nodes.knowledgeBase.embeddingModelIncompatible": "Incompatível", "nodes.knowledgeBase.embeddingModelIsInvalid": "O modelo de incorporação é inválido", "nodes.knowledgeBase.embeddingModelIsRequired": "Modelo de incorporação é necessário", + "nodes.knowledgeBase.embeddingModelNotConfigured": "Modelo de embedding não configurado", "nodes.knowledgeBase.indexMethodIsRequired": "O método de índice é necessário", + "nodes.knowledgeBase.notConfigured": "Não configurado", "nodes.knowledgeBase.rerankingModelIsInvalid": "O modelo de reclassificação é inválido", "nodes.knowledgeBase.rerankingModelIsRequired": "Um modelo de reclassificação é necessário", 
"nodes.knowledgeBase.retrievalSettingIsRequired": "A configuração de recuperação é necessária", @@ -872,6 +880,7 @@ "nodes.templateTransform.codeSupportTip": "Suporta apenas Jinja2", "nodes.templateTransform.inputVars": "Variáveis de entrada", "nodes.templateTransform.outputVars.output": "Conteúdo transformado", + "nodes.tool.authorizationRequired": "Autorização necessária", "nodes.tool.authorize": "Autorizar", "nodes.tool.inputVars": "Variáveis de entrada", "nodes.tool.insertPlaceholder1": "Digite ou pressione", @@ -1062,10 +1071,12 @@ "panel.change": "Mudar", "panel.changeBlock": "Mudar Nó", "panel.checklist": "Lista de verificação", + "panel.checklistDescription": "Resolva os seguintes problemas antes de publicar", "panel.checklistResolved": "Todos os problemas foram resolvidos", "panel.checklistTip": "Certifique-se de que todos os problemas foram resolvidos antes de publicar", "panel.createdBy": "Criado por ", "panel.goTo": "Ir para", + "panel.goToFix": "Ir para correção", "panel.helpLink": "Ajuda", "panel.maximize": "Maximize Canvas", "panel.minimize": "Sair do Modo Tela Cheia", diff --git a/web/i18n/ro-RO/app-debug.json b/web/i18n/ro-RO/app-debug.json index 6245ca680d..d2b8294804 100644 --- a/web/i18n/ro-RO/app-debug.json +++ b/web/i18n/ro-RO/app-debug.json @@ -235,11 +235,16 @@ "inputs.run": "RULARE", "inputs.title": "Depanare și previzualizare", "inputs.userInputField": "Câmp de intrare utilizator", + "manageModels": "Gestionare modele", "modelConfig.modeType.chat": "Chat", "modelConfig.modeType.completion": "Completare", "modelConfig.model": "Model", "modelConfig.setTone": "Setați tonul răspunsurilor", "modelConfig.title": "Model și Parametri", + "noModelProviderConfigured": "Niciun furnizor de modele configurat", + "noModelProviderConfiguredTip": "Instalați sau configurați un furnizor de modele pentru a începe.", + "noModelSelected": "Niciun model selectat", + "noModelSelectedTip": "configurați un model mai sus pentru a continua.", "noResult": "Ieșirea va 
fi afișată aici.", "notSetAPIKey.description": "Cheia furnizorului LLM nu a fost setată și trebuie să fie setată înainte de depanare.", "notSetAPIKey.settingBtn": "Du-te la setări", diff --git a/web/i18n/ro-RO/common.json b/web/i18n/ro-RO/common.json index 8fdb09b3c1..306439768b 100644 --- a/web/i18n/ro-RO/common.json +++ b/web/i18n/ro-RO/common.json @@ -340,11 +340,25 @@ "modelProvider.auth.unAuthorized": "Neautorizat", "modelProvider.buyQuota": "Cumpără cotă", "modelProvider.callTimes": "Apeluri", + "modelProvider.card.aiCreditsInUse": "AI credits în uz", + "modelProvider.card.aiCreditsOption": "AI credits", + "modelProvider.card.apiKeyOption": "API Key", + "modelProvider.card.apiKeyRequired": "API key necesar", + "modelProvider.card.apiKeyUnavailableFallback": "API Key indisponibil, se utilizează AI credits", + "modelProvider.card.apiKeyUnavailableFallbackDescription": "Verificați configurația API key pentru a reveni", "modelProvider.card.buyQuota": "Cumpără cotă", "modelProvider.card.callTimes": "Apeluri", + "modelProvider.card.creditsExhaustedDescription": "Vă rugăm să faceți upgrade la plan sau să configurați un API key", + "modelProvider.card.creditsExhaustedFallback": "AI credits epuizate, se utilizează API Key", + "modelProvider.card.creditsExhaustedFallbackDescription": "Faceți upgrade la plan pentru a relua prioritatea AI credits.", + "modelProvider.card.creditsExhaustedMessage": "AI credits au fost epuizate", "modelProvider.card.modelAPI": "Modelele {{modelName}} folosesc cheia API.", "modelProvider.card.modelNotSupported": "Modelele {{modelName}} nu sunt instalate.", "modelProvider.card.modelSupported": "Modelele {{modelName}} folosesc această cotă.", + "modelProvider.card.noApiKeysDescription": "Adăugați un API key pentru a începe să utilizați propriile credențiale de model.", + "modelProvider.card.noApiKeysFallback": "Fără API key-uri, se utilizează AI credits", + "modelProvider.card.noApiKeysTitle": "Niciun API key configurat încă", + 
"modelProvider.card.noAvailableUsage": "Nicio utilizare disponibilă", "modelProvider.card.onTrial": "În probă", "modelProvider.card.paid": "Plătit", "modelProvider.card.priorityUse": "Utilizare prioritară", @@ -353,6 +367,11 @@ "modelProvider.card.removeKey": "Elimină cheia API", "modelProvider.card.tip": "Creditele de mesaje acceptă modele de la {{modelNames}}. Prioritate va fi acordată cotei plătite. Cota gratuită va fi utilizată după epuizarea cotei plătite.", "modelProvider.card.tokens": "Jetoane", + "modelProvider.card.unavailable": "Indisponibil", + "modelProvider.card.upgradePlan": "faceți upgrade la plan", + "modelProvider.card.usageLabel": "Utilizare", + "modelProvider.card.usagePriority": "Prioritate de utilizare", + "modelProvider.card.usagePriorityTip": "Setați ce resursă să fie utilizată prima la rularea modelelor.", "modelProvider.collapse": "Restrânge", "modelProvider.config": "Configurare", "modelProvider.configLoadBalancing": "Echilibrarea încărcării de configurare", @@ -387,9 +406,11 @@ "modelProvider.model": "Model", "modelProvider.modelAndParameters": "Model și parametri", "modelProvider.modelHasBeenDeprecated": "Acest model a fost depreciat", + "modelProvider.modelSettings": "Setări model", "modelProvider.models": "Modele", "modelProvider.modelsNum": "{{num}} Modele", "modelProvider.noModelFound": "Nu a fost găsit niciun model pentru {{model}}", + "modelProvider.noneConfigured": "Configurați un model de sistem implicit pentru a rula aplicații", "modelProvider.notConfigured": "Modelul de sistem nu a fost încă configurat complet, iar unele funcții pot fi indisponibile.", "modelProvider.parameters": "PARAMETRI", "modelProvider.parametersInvalidRemoved": "Unele parametrii sunt invalizi și au fost eliminați.", @@ -403,8 +424,25 @@ "modelProvider.resetDate": "Resetare la {{date}}", "modelProvider.searchModel": "Model de căutare", "modelProvider.selectModel": "Selectați modelul dvs.", + "modelProvider.selector.aiCredits": "AI credits", + 
"modelProvider.selector.apiKeyUnavailable": "API Key indisponibil", + "modelProvider.selector.apiKeyUnavailableTip": "API key-ul a fost eliminat. Vă rugăm să configurați un nou API key.", + "modelProvider.selector.configure": "Configurare", + "modelProvider.selector.configureRequired": "Configurare necesară", + "modelProvider.selector.creditsExhausted": "Credite epuizate", + "modelProvider.selector.creditsExhaustedTip": "AI credits au fost epuizate. Vă rugăm să faceți upgrade la plan sau să adăugați un API key.", + "modelProvider.selector.disabled": "Dezactivat", + "modelProvider.selector.discoverMoreInMarketplace": "Descoperiți mai multe în Marketplace", "modelProvider.selector.emptySetting": "Vă rugăm să mergeți la setări pentru a configura", "modelProvider.selector.emptyTip": "Nu există modele disponibile", + "modelProvider.selector.fromMarketplace": "Din Marketplace", + "modelProvider.selector.incompatible": "Incompatibil", + "modelProvider.selector.incompatibleTip": "Acest model nu este disponibil în versiunea curentă. Vă rugăm să selectați alt model disponibil.", + "modelProvider.selector.install": "Instalare", + "modelProvider.selector.modelProviderSettings": "Setări furnizor de modele", + "modelProvider.selector.noProviderConfigured": "Niciun furnizor de modele configurat", + "modelProvider.selector.noProviderConfiguredDesc": "Răsfoiți Marketplace pentru a instala unul sau configurați furnizorii în setări.", + "modelProvider.selector.onlyCompatibleModelsShown": "Sunt afișate doar modelele compatibile", "modelProvider.selector.rerankTip": "Vă rugăm să configurați modelul de reordonare", "modelProvider.selector.tip": "Acest model a fost eliminat. 
Vă rugăm să adăugați un model sau să selectați un alt model.", "modelProvider.setupModelFirst": "Vă rugăm să configurați mai întâi modelul", diff --git a/web/i18n/ro-RO/plugin.json b/web/i18n/ro-RO/plugin.json index c0295b64d6..15523b6c62 100644 --- a/web/i18n/ro-RO/plugin.json +++ b/web/i18n/ro-RO/plugin.json @@ -3,6 +3,7 @@ "action.delete": "Eliminați pluginul", "action.deleteContentLeft": "Doriți să eliminați", "action.deleteContentRight": "plugin?", + "action.deleteSuccess": "Plugin eliminat cu succes", "action.pluginInfo": "Informații despre plugin", "action.usedInApps": "Acest plugin este folosit în aplicațiile {{num}}.", "allCategories": "Toate categoriile", @@ -114,6 +115,7 @@ "detailPanel.operation.install": "Instala", "detailPanel.operation.remove": "Depărta", "detailPanel.operation.update": "Actualiza", + "detailPanel.operation.updateTooltip": "Actualizați pentru a accesa cele mai recente modele.", "detailPanel.operation.viewDetail": "Vezi detalii", "detailPanel.serviceOk": "Serviciu OK", "detailPanel.strategyNum": "{{num}} {{strategy}} INCLUS", @@ -231,12 +233,18 @@ "source.local": "Fișier pachet local", "source.marketplace": "Târg", "task.clearAll": "Ștergeți tot", + "task.errorMsg.github": "Acest plugin nu a putut fi instalat automat.\nVă rugăm să îl instalați de pe GitHub.", + "task.errorMsg.marketplace": "Acest plugin nu a putut fi instalat automat.\nVă rugăm să îl instalați din Marketplace.", + "task.errorMsg.unknown": "Acest plugin nu a putut fi instalat.\nSursa pluginului nu a putut fi identificată.", "task.errorPlugins": "Failed to Install Plugins", "task.installError": "{{errorLength}} plugin-urile nu s-au instalat, faceți clic pentru a vizualiza", + "task.installFromGithub": "Instalare de pe GitHub", + "task.installFromMarketplace": "Instalare din Marketplace", "task.installSuccess": "{{successLength}} plugins installed successfully", "task.installed": "Installed", "task.installedError": "{{errorLength}} plugin-urile nu s-au instalat", 
"task.installing": "Se instalează pluginuri.", + "task.installingHint": "Se instalează... Aceasta poate dura câteva minute.", "task.installingWithError": "Instalarea pluginurilor {{installingLength}}, {{successLength}} succes, {{errorLength}} eșuat", "task.installingWithSuccess": "Instalarea pluginurilor {{installingLength}}, {{successLength}} succes.", "task.runningPlugins": "Installing Plugins", diff --git a/web/i18n/ro-RO/workflow.json b/web/i18n/ro-RO/workflow.json index a68d477980..e34ba42007 100644 --- a/web/i18n/ro-RO/workflow.json +++ b/web/i18n/ro-RO/workflow.json @@ -305,6 +305,7 @@ "error.operations.updatingWorkflow": "actualizarea fluxului de lucru", "error.startNodeRequired": "Vă rugăm să adăugați mai întâi un nod de pornire înainte de {{operation}}", "errorMsg.authRequired": "Autorizarea este necesară", + "errorMsg.configureModel": "Configurați un model", "errorMsg.fieldRequired": "{{field}} este obligatoriu", "errorMsg.fields.code": "Cod", "errorMsg.fields.model": "Model", @@ -314,6 +315,7 @@ "errorMsg.fields.visionVariable": "Vizibilitate variabilă", "errorMsg.invalidJson": "{{field}} este un JSON invalid", "errorMsg.invalidVariable": "Variabilă invalidă", + "errorMsg.modelPluginNotInstalled": "Variabilă invalidă. 
Configurați un model pentru a activa această variabilă.", "errorMsg.noValidTool": "{{field}} nu a fost selectat niciun instrument valid", "errorMsg.rerankModelRequired": "Înainte de a activa modelul de reclasificare, vă rugăm să confirmați că modelul a fost configurat cu succes în setări.", "errorMsg.startNodeRequired": "Vă rugăm să adăugați mai întâi un nod de pornire înainte de {{operation}}", @@ -444,6 +446,7 @@ "nodes.common.memory.windowSize": "Dimensiunea ferestrei", "nodes.common.outputVars": "Variabile de ieșire", "nodes.common.pluginNotInstalled": "Pluginul nu este instalat", + "nodes.common.pluginsNotInstalled": "{{count}} pluginuri neinstalate", "nodes.common.retry.maxRetries": "numărul maxim de încercări", "nodes.common.retry.ms": "Ms", "nodes.common.retry.retries": "{{num}} Încercări", @@ -686,9 +689,14 @@ "nodes.knowledgeBase.chunksInput": "Bucăți", "nodes.knowledgeBase.chunksInputTip": "Variabila de intrare a nodului bazei de cunoștințe este Chunks. Tipul variabilei este un obiect cu un Șchema JSON specific care trebuie să fie coerent cu structura de chunk selectată.", "nodes.knowledgeBase.chunksVariableIsRequired": "Variabila Chunks este obligatorie", + "nodes.knowledgeBase.embeddingModelApiKeyUnavailable": "API key indisponibil", + "nodes.knowledgeBase.embeddingModelCreditsExhausted": "Credite epuizate", + "nodes.knowledgeBase.embeddingModelIncompatible": "Incompatibil", "nodes.knowledgeBase.embeddingModelIsInvalid": "Modelul de încorporare este invalid", "nodes.knowledgeBase.embeddingModelIsRequired": "Este necesar un model de încorporare", + "nodes.knowledgeBase.embeddingModelNotConfigured": "Modelul de embedding nu este configurat", "nodes.knowledgeBase.indexMethodIsRequired": "Este necesară metoda indexului", + "nodes.knowledgeBase.notConfigured": "Neconfigurat", "nodes.knowledgeBase.rerankingModelIsInvalid": "Modelul de reordonare este invalid", "nodes.knowledgeBase.rerankingModelIsRequired": "Este necesar un model de reordonare", 
"nodes.knowledgeBase.retrievalSettingIsRequired": "Setarea de recuperare este necesară", @@ -872,6 +880,7 @@ "nodes.templateTransform.codeSupportTip": "Suportă doar Jinja2", "nodes.templateTransform.inputVars": "Variabile de intrare", "nodes.templateTransform.outputVars.output": "Conținut transformat", + "nodes.tool.authorizationRequired": "Autorizare necesară", "nodes.tool.authorize": "Autorizați", "nodes.tool.inputVars": "Variabile de intrare", "nodes.tool.insertPlaceholder1": "Scrieți sau apăsați", @@ -1062,10 +1071,12 @@ "panel.change": "Schimbă", "panel.changeBlock": "Schimbă nodul", "panel.checklist": "Lista de verificare", + "panel.checklistDescription": "Rezolvați următoarele probleme înainte de publicare", "panel.checklistResolved": "Toate problemele au fost rezolvate", "panel.checklistTip": "Asigurați-vă că toate problemele sunt rezolvate înainte de publicare", "panel.createdBy": "Creat de ", "panel.goTo": "Du-te la", + "panel.goToFix": "Mergi la remediere", "panel.helpLink": "Ajutor", "panel.maximize": "Maximize Canvas", "panel.minimize": "Iesi din modul pe tot ecranul", diff --git a/web/i18n/ru-RU/app-debug.json b/web/i18n/ru-RU/app-debug.json index fdc797da2f..514958101d 100644 --- a/web/i18n/ru-RU/app-debug.json +++ b/web/i18n/ru-RU/app-debug.json @@ -235,11 +235,16 @@ "inputs.run": "ЗАПУСТИТЬ", "inputs.title": "Отладка и предварительный просмотр", "inputs.userInputField": "Поле пользовательского ввода", + "manageModels": "Управление моделями", "modelConfig.modeType.chat": "Чат", "modelConfig.modeType.completion": "Завершение", "modelConfig.model": "Модель", "modelConfig.setTone": "Установить тон ответов", "modelConfig.title": "Модель и параметры", + "noModelProviderConfigured": "Поставщик моделей не настроен", + "noModelProviderConfiguredTip": "Установите или настройте поставщика моделей, чтобы начать работу.", + "noModelSelected": "Модель не выбрана", + "noModelSelectedTip": "Настройте модель выше, чтобы продолжить.", "noResult": "Вывод будет 
отображаться здесь.", "notSetAPIKey.description": "Ключ поставщика LLM не установлен, его необходимо установить перед отладкой.", "notSetAPIKey.settingBtn": "Перейти к настройкам", diff --git a/web/i18n/ru-RU/common.json b/web/i18n/ru-RU/common.json index dc301e4ed9..aec9a69483 100644 --- a/web/i18n/ru-RU/common.json +++ b/web/i18n/ru-RU/common.json @@ -340,11 +340,25 @@ "modelProvider.auth.unAuthorized": "Неавторизованный", "modelProvider.buyQuota": "Купить квоту", "modelProvider.callTimes": "Количество вызовов", + "modelProvider.card.aiCreditsInUse": "Используются AI-кредиты", + "modelProvider.card.aiCreditsOption": "AI-кредиты", + "modelProvider.card.apiKeyOption": "API Key", + "modelProvider.card.apiKeyRequired": "Требуется API-ключ", + "modelProvider.card.apiKeyUnavailableFallback": "API Key недоступен, используются AI-кредиты", + "modelProvider.card.apiKeyUnavailableFallbackDescription": "Проверьте настройки API Key для переключения обратно", "modelProvider.card.buyQuota": "Купить квоту", "modelProvider.card.callTimes": "Количество вызовов", + "modelProvider.card.creditsExhaustedDescription": "Пожалуйста, обновите тарифный план или настройте API-ключ", + "modelProvider.card.creditsExhaustedFallback": "AI-кредиты исчерпаны, используется API Key", + "modelProvider.card.creditsExhaustedFallbackDescription": "Обновите тарифный план, чтобы возобновить приоритетное использование AI-кредитов.", + "modelProvider.card.creditsExhaustedMessage": "AI-кредиты исчерпаны", "modelProvider.card.modelAPI": "Модели {{modelName}} используют API-ключ.", "modelProvider.card.modelNotSupported": "Модели {{modelName}} не установлены.", "modelProvider.card.modelSupported": "Модели {{modelName}} используют эту квоту.", + "modelProvider.card.noApiKeysDescription": "Добавьте API-ключ, чтобы использовать собственные учётные данные модели.", + "modelProvider.card.noApiKeysFallback": "API-ключи не настроены, используются AI-кредиты", + "modelProvider.card.noApiKeysTitle": "API-ключи ещё не 
настроены", + "modelProvider.card.noAvailableUsage": "Нет доступного использования", "modelProvider.card.onTrial": "Пробная версия", "modelProvider.card.paid": "Платный", "modelProvider.card.priorityUse": "Приоритетное использование", @@ -353,6 +367,11 @@ "modelProvider.card.removeKey": "Удалить API-ключ", "modelProvider.card.tip": "Кредиты сообщений поддерживают модели от {{modelNames}}. Приоритет будет отдаваться платной квоте. Бесплатная квота будет использоваться после исчерпания платной квоты.", "modelProvider.card.tokens": "Токены", + "modelProvider.card.unavailable": "Недоступно", + "modelProvider.card.upgradePlan": "обновите тарифный план", + "modelProvider.card.usageLabel": "Использование", + "modelProvider.card.usagePriority": "Приоритет использования", + "modelProvider.card.usagePriorityTip": "Укажите, какой ресурс использовать первым при запуске моделей.", "modelProvider.collapse": "Свернуть", "modelProvider.config": "Настройка", "modelProvider.configLoadBalancing": "Настроить балансировку нагрузки", @@ -387,9 +406,11 @@ "modelProvider.model": "Модель", "modelProvider.modelAndParameters": "Модель и параметры", "modelProvider.modelHasBeenDeprecated": "Эта модель устарела", + "modelProvider.modelSettings": "Настройки моделей", "modelProvider.models": "Модели", "modelProvider.modelsNum": "{{num}} Моделей", "modelProvider.noModelFound": "Модель не найдена для {{model}}", + "modelProvider.noneConfigured": "Настройте системную модель по умолчанию для запуска приложений", "modelProvider.notConfigured": "Системная модель еще не полностью настроена, и некоторые функции могут быть недоступны.", "modelProvider.parameters": "ПАРАМЕТРЫ", "modelProvider.parametersInvalidRemoved": "Некоторые параметры недействительны и были удалены", @@ -403,8 +424,25 @@ "modelProvider.resetDate": "Сброс {{date}}", "modelProvider.searchModel": "Поиск модели", "modelProvider.selectModel": "Выберите свою модель", + "modelProvider.selector.aiCredits": "AI-кредиты", + 
"modelProvider.selector.apiKeyUnavailable": "API Key недоступен", + "modelProvider.selector.apiKeyUnavailableTip": "API Key был удалён. Пожалуйста, настройте новый API Key.", + "modelProvider.selector.configure": "Настроить", + "modelProvider.selector.configureRequired": "Требуется настройка", + "modelProvider.selector.creditsExhausted": "Кредиты исчерпаны", + "modelProvider.selector.creditsExhaustedTip": "Ваши AI-кредиты исчерпаны. Обновите тарифный план или добавьте API-ключ.", + "modelProvider.selector.disabled": "Отключено", + "modelProvider.selector.discoverMoreInMarketplace": "Найти больше в Marketplace", "modelProvider.selector.emptySetting": "Пожалуйста, перейдите в настройки для настройки", "modelProvider.selector.emptyTip": "Нет доступных моделей", + "modelProvider.selector.fromMarketplace": "Из Marketplace", + "modelProvider.selector.incompatible": "Несовместимо", + "modelProvider.selector.incompatibleTip": "Эта модель недоступна в текущей версии. Пожалуйста, выберите другую доступную модель.", + "modelProvider.selector.install": "Установить", + "modelProvider.selector.modelProviderSettings": "Настройки поставщика моделей", + "modelProvider.selector.noProviderConfigured": "Поставщик моделей не настроен", + "modelProvider.selector.noProviderConfiguredDesc": "Установите из Marketplace или настройте поставщиков в настройках.", + "modelProvider.selector.onlyCompatibleModelsShown": "Показаны только совместимые модели", "modelProvider.selector.rerankTip": "Пожалуйста, настройте модель повторного ранжирования", "modelProvider.selector.tip": "Эта модель была удалена. 
Пожалуйста, добавьте модель или выберите другую модель.", "modelProvider.setupModelFirst": "Пожалуйста, сначала настройте свою модель", diff --git a/web/i18n/ru-RU/plugin.json b/web/i18n/ru-RU/plugin.json index 3b10a3a995..52cb6cbd55 100644 --- a/web/i18n/ru-RU/plugin.json +++ b/web/i18n/ru-RU/plugin.json @@ -3,6 +3,7 @@ "action.delete": "Удалить плагин", "action.deleteContentLeft": "Вы хотели бы удалить", "action.deleteContentRight": "Плагин?", + "action.deleteSuccess": "Плагин успешно удалён", "action.pluginInfo": "Информация о плагине", "action.usedInApps": "Этот плагин используется в приложениях {{num}}.", "allCategories": "Все категории", @@ -114,6 +115,7 @@ "detailPanel.operation.install": "Устанавливать", "detailPanel.operation.remove": "Убирать", "detailPanel.operation.update": "Обновлять", + "detailPanel.operation.updateTooltip": "Обновите для доступа к последним моделям.", "detailPanel.operation.viewDetail": "Подробнее", "detailPanel.serviceOk": "Услуга ОК", "detailPanel.strategyNum": "{{num}} {{strategy}} ВКЛЮЧЕННЫЙ", @@ -231,12 +233,18 @@ "source.local": "Локальный файл пакета", "source.marketplace": "Рынок", "task.clearAll": "Очистить все", + "task.errorMsg.github": "Не удалось автоматически установить этот плагин.\nПожалуйста, установите его из GitHub.", + "task.errorMsg.marketplace": "Не удалось автоматически установить этот плагин.\nПожалуйста, установите его из Marketplace.", + "task.errorMsg.unknown": "Не удалось установить этот плагин.\nНе удалось определить источник плагина.", "task.errorPlugins": "Failed to Install Plugins", "task.installError": "Плагины {{errorLength}} не удалось установить, нажмите для просмотра", + "task.installFromGithub": "Установить из GitHub", + "task.installFromMarketplace": "Установить из Marketplace", "task.installSuccess": "{{successLength}} plugins installed successfully", "task.installed": "Installed", "task.installedError": "плагины {{errorLength}} не удалось установить", "task.installing": "Установка плагинов.", 
+ "task.installingHint": "Установка... Это может занять несколько минут.", "task.installingWithError": "Установка плагинов {{installingLength}}, {{successLength}} успех, {{errorLength}} неудачный", "task.installingWithSuccess": "Установка плагинов {{installingLength}}, {{successLength}} успех.", "task.runningPlugins": "Installing Plugins", diff --git a/web/i18n/ru-RU/workflow.json b/web/i18n/ru-RU/workflow.json index 9265e4e6d6..73cdad253a 100644 --- a/web/i18n/ru-RU/workflow.json +++ b/web/i18n/ru-RU/workflow.json @@ -305,6 +305,7 @@ "error.operations.updatingWorkflow": "обновление рабочего процесса", "error.startNodeRequired": "Пожалуйста, сначала добавьте начальный узел перед {{operation}}", "errorMsg.authRequired": "Требуется авторизация", + "errorMsg.configureModel": "Настройте модель", "errorMsg.fieldRequired": "{{field}} обязательно для заполнения", "errorMsg.fields.code": "Код", "errorMsg.fields.model": "Модель", @@ -314,6 +315,7 @@ "errorMsg.fields.visionVariable": "Переменная зрения", "errorMsg.invalidJson": "{{field}} неверный JSON", "errorMsg.invalidVariable": "Неверная переменная", + "errorMsg.modelPluginNotInstalled": "Недопустимая переменная. 
Настройте модель, чтобы активировать эту переменную.", "errorMsg.noValidTool": "{{field}} не выбран валидный инструмент", "errorMsg.rerankModelRequired": "Перед включением модели повторного ранжирования убедитесь, что модель успешно настроена в настройках.", "errorMsg.startNodeRequired": "Пожалуйста, сначала добавьте начальный узел перед {{operation}}", @@ -444,6 +446,7 @@ "nodes.common.memory.windowSize": "Размер окна", "nodes.common.outputVars": "Выходные переменные", "nodes.common.pluginNotInstalled": "Плагин не установлен", + "nodes.common.pluginsNotInstalled": "{{count}} плагинов не установлено", "nodes.common.retry.maxRetries": "максимальное количество повторных попыток", "nodes.common.retry.ms": "госпожа", "nodes.common.retry.retries": "{{num}} Повторных попыток", @@ -686,9 +689,14 @@ "nodes.knowledgeBase.chunksInput": "Куски", "nodes.knowledgeBase.chunksInputTip": "Входная переменная узла базы знаний - это Чанки. Тип переменной является объектом с определенной схемой JSON, которая должна соответствовать выбранной структуре чанка.", "nodes.knowledgeBase.chunksVariableIsRequired": "Переменная chunks обязательна", + "nodes.knowledgeBase.embeddingModelApiKeyUnavailable": "API-ключ недоступен", + "nodes.knowledgeBase.embeddingModelCreditsExhausted": "Кредиты исчерпаны", + "nodes.knowledgeBase.embeddingModelIncompatible": "Несовместимо", "nodes.knowledgeBase.embeddingModelIsInvalid": "Модель встраивания недействительна", "nodes.knowledgeBase.embeddingModelIsRequired": "Требуется модель встраивания", + "nodes.knowledgeBase.embeddingModelNotConfigured": "Модель эмбеддингов не настроена", "nodes.knowledgeBase.indexMethodIsRequired": "Метод index является обязательным", + "nodes.knowledgeBase.notConfigured": "Не настроено", "nodes.knowledgeBase.rerankingModelIsInvalid": "Модель повторной ранжировки недействительна", "nodes.knowledgeBase.rerankingModelIsRequired": "Требуется модель перераспределения рангов", "nodes.knowledgeBase.retrievalSettingIsRequired": "Настройка 
извлечения обязательна", @@ -872,6 +880,7 @@ "nodes.templateTransform.codeSupportTip": "Поддерживает только Jinja2", "nodes.templateTransform.inputVars": "Входные переменные", "nodes.templateTransform.outputVars.output": "Преобразованный контент", + "nodes.tool.authorizationRequired": "Требуется авторизация", "nodes.tool.authorize": "Авторизовать", "nodes.tool.inputVars": "Входные переменные", "nodes.tool.insertPlaceholder1": "Наберите или нажмите", @@ -1062,10 +1071,12 @@ "panel.change": "Изменить", "panel.changeBlock": "Изменить узел", "panel.checklist": "Контрольный список", + "panel.checklistDescription": "Устраните следующие проблемы перед публикацией", "panel.checklistResolved": "Все проблемы решены", "panel.checklistTip": "Убедитесь, что все проблемы решены перед публикацией", "panel.createdBy": "Создано ", "panel.goTo": "Перейти к", + "panel.goToFix": "Перейти к исправлению", "panel.helpLink": "Помощь", "panel.maximize": "Максимизировать холст", "panel.minimize": "Выйти из полноэкранного режима", diff --git a/web/i18n/sl-SI/app-debug.json b/web/i18n/sl-SI/app-debug.json index 948e3dbc67..cf3d6c9ec4 100644 --- a/web/i18n/sl-SI/app-debug.json +++ b/web/i18n/sl-SI/app-debug.json @@ -235,11 +235,16 @@ "inputs.run": "TEČI", "inputs.title": "Odpravljanje napak in predogled", "inputs.userInputField": "Uporabniško polje za vnos", + "manageModels": "Upravljanje modelov", "modelConfig.modeType.chat": "Chat", "modelConfig.modeType.completion": "Dokončati", "modelConfig.model": "Model", "modelConfig.setTone": "Nastavitev tona odzivov", "modelConfig.title": "Model in parametri", + "noModelProviderConfigured": "Ponudnik modelov ni konfiguriran", + "noModelProviderConfiguredTip": "Za začetek namestite ali konfigurirajte ponudnika modelov.", + "noModelSelected": "Noben model ni izbran", + "noModelSelectedTip": "Za nadaljevanje konfigurirajte model zgoraj.", "noResult": "Tukaj bo prikazan izhod.", "notSetAPIKey.description": "Ključ ponudnika LLM ni nastavljen. 
Pred odpravljanjem napak je treba nastaviti ključ.", "notSetAPIKey.settingBtn": "Pojdi v nastavitve", diff --git a/web/i18n/sl-SI/common.json b/web/i18n/sl-SI/common.json index 247874bb30..6ec4fe430c 100644 --- a/web/i18n/sl-SI/common.json +++ b/web/i18n/sl-SI/common.json @@ -340,11 +340,25 @@ "modelProvider.auth.unAuthorized": "Neavtorizirano", "modelProvider.buyQuota": "Kupi kvoto", "modelProvider.callTimes": "Število klicev", + "modelProvider.card.aiCreditsInUse": "Krediti AI v uporabi", + "modelProvider.card.aiCreditsOption": "Krediti AI", + "modelProvider.card.apiKeyOption": "API ključ", + "modelProvider.card.apiKeyRequired": "Potreben je API ključ", + "modelProvider.card.apiKeyUnavailableFallback": "API ključ ni na voljo, zdaj se uporabljajo krediti AI", + "modelProvider.card.apiKeyUnavailableFallbackDescription": "Preverite konfiguracijo API ključa za preklop nazaj", "modelProvider.card.buyQuota": "Kupi kvoto", "modelProvider.card.callTimes": "Časi klicev", + "modelProvider.card.creditsExhaustedDescription": "Prosimo, nadgradite načrt ali konfigurirajte API ključ", + "modelProvider.card.creditsExhaustedFallback": "Krediti AI so porabljeni, zdaj se uporablja API ključ", + "modelProvider.card.creditsExhaustedFallbackDescription": "Nadgradite načrt za nadaljevanje prednostne uporabe kreditov AI.", + "modelProvider.card.creditsExhaustedMessage": "Krediti AI so bili porabljeni", "modelProvider.card.modelAPI": "Modeli {{modelName}} uporabljajo API ključ.", "modelProvider.card.modelNotSupported": "Modeli {{modelName}} niso nameščeni.", "modelProvider.card.modelSupported": "Modeli {{modelName}} uporabljajo to kvoto.", + "modelProvider.card.noApiKeysDescription": "Dodajte API ključ za začetek uporabe lastnih poverilnic modela.", + "modelProvider.card.noApiKeysFallback": "Ni API ključev, namesto tega se uporabljajo krediti AI", + "modelProvider.card.noApiKeysTitle": "Še ni konfiguriranih API ključev", + "modelProvider.card.noAvailableUsage": "Ni razpoložljive 
uporabe", "modelProvider.card.onTrial": "Na preizkusu", "modelProvider.card.paid": "Plačano", "modelProvider.card.priorityUse": "Prednostna uporaba", @@ -353,6 +367,11 @@ "modelProvider.card.removeKey": "Odstrani API ključ", "modelProvider.card.tip": "Krediti za sporočila podpirajo modele od {{modelNames}}. Prednostno se bo uporabila plačana kvota. Brezplačna kvota se bo uporabila, ko bo plačana kvota porabljena.", "modelProvider.card.tokens": "Žetoni", + "modelProvider.card.unavailable": "Ni na voljo", + "modelProvider.card.upgradePlan": "nadgradite načrt", + "modelProvider.card.usageLabel": "Poraba", + "modelProvider.card.usagePriority": "Prednost porabe", + "modelProvider.card.usagePriorityTip": "Nastavite, kateri vir se bo najprej uporabil pri zaganjanju modelov.", "modelProvider.collapse": "Strni", "modelProvider.config": "Konfiguracija", "modelProvider.configLoadBalancing": "Konfiguracija uravnoteženja obremenitev", @@ -387,9 +406,11 @@ "modelProvider.model": "Model", "modelProvider.modelAndParameters": "Model in parametri", "modelProvider.modelHasBeenDeprecated": "Ta model je zastarel", + "modelProvider.modelSettings": "Nastavitve modela", "modelProvider.models": "Modeli", "modelProvider.modelsNum": "{{num}} modelov", "modelProvider.noModelFound": "Za {{model}} ni najden noben model", + "modelProvider.noneConfigured": "Konfigurirajte privzeti sistemski model za zaganjanje aplikacij", "modelProvider.notConfigured": "Sistemski model še ni popolnoma konfiguriran, nekatere funkcije morda ne bodo na voljo.", "modelProvider.parameters": "PARAMETRI", "modelProvider.parametersInvalidRemoved": "Nekateri parametri so neveljavni in so bili odstranjeni.", @@ -403,8 +424,25 @@ "modelProvider.resetDate": "Ponastavi {{date}}", "modelProvider.searchModel": "Model iskanja", "modelProvider.selectModel": "Izberite svoj model", + "modelProvider.selector.aiCredits": "Krediti AI", + "modelProvider.selector.apiKeyUnavailable": "API ključ ni na voljo", + 
"modelProvider.selector.apiKeyUnavailableTip": "API ključ je bil odstranjen. Prosimo, konfigurirajte nov API ključ.", + "modelProvider.selector.configure": "Konfiguriraj", + "modelProvider.selector.configureRequired": "Konfiguracija je obvezna", + "modelProvider.selector.creditsExhausted": "Krediti so porabljeni", + "modelProvider.selector.creditsExhaustedTip": "Vaši krediti AI so bili porabljeni. Prosimo, nadgradite načrt ali dodajte API ključ.", + "modelProvider.selector.disabled": "Onemogočeno", + "modelProvider.selector.discoverMoreInMarketplace": "Odkrijte več v Tržnici", "modelProvider.selector.emptySetting": "Prosimo, pojdite v nastavitve za konfiguracijo", "modelProvider.selector.emptyTip": "Ni razpoložljivih modelov", + "modelProvider.selector.fromMarketplace": "Iz Tržnice", + "modelProvider.selector.incompatible": "Nezdružljivo", + "modelProvider.selector.incompatibleTip": "Ta model ni na voljo v trenutni različici. Izberite drug razpoložljiv model.", + "modelProvider.selector.install": "Namesti", + "modelProvider.selector.modelProviderSettings": "Nastavitve ponudnika modelov", + "modelProvider.selector.noProviderConfigured": "Noben ponudnik modelov ni konfiguriran", + "modelProvider.selector.noProviderConfiguredDesc": "Poiščite v Tržnici za namestitev ali konfigurirajte ponudnike v nastavitvah.", + "modelProvider.selector.onlyCompatibleModelsShown": "Prikazani so samo združljivi modeli", "modelProvider.selector.rerankTip": "Prosimo, nastavite model za prerazvrstitev", "modelProvider.selector.tip": "Ta model je bil odstranjen. 
Prosimo, dodajte model ali izberite drugega.", "modelProvider.setupModelFirst": "Najprej nastavite svoj model", diff --git a/web/i18n/sl-SI/plugin.json b/web/i18n/sl-SI/plugin.json index 0f84e91c80..91f6c670bf 100644 --- a/web/i18n/sl-SI/plugin.json +++ b/web/i18n/sl-SI/plugin.json @@ -3,6 +3,7 @@ "action.delete": "Odstrani vtičnik", "action.deleteContentLeft": "Ali želite odstraniti", "action.deleteContentRight": "vtičnik?", + "action.deleteSuccess": "Vtičnik je bil uspešno odstranjen", "action.pluginInfo": "Informacije o vtičniku", "action.usedInApps": "Ta vtičnik se uporablja v {{num}} aplikacijah.", "allCategories": "Vse kategorije", @@ -114,6 +115,7 @@ "detailPanel.operation.install": "Namestite", "detailPanel.operation.remove": "Odstrani", "detailPanel.operation.update": "Posodobitev", + "detailPanel.operation.updateTooltip": "Posodobite za dostop do najnovejših modelov.", "detailPanel.operation.viewDetail": "Oglej si podrobnosti", "detailPanel.serviceOk": "Storitve so v redu", "detailPanel.strategyNum": "{{num}} {{strategy}} VKLJUČENO", @@ -231,12 +233,18 @@ "source.local": "Lokalna paketna datoteka", "source.marketplace": "Tržnica", "task.clearAll": "Počisti vse", + "task.errorMsg.github": "Ta vtičnik ni bil samodejno nameščen.\nProsimo, namestite ga iz GitHuba.", + "task.errorMsg.marketplace": "Ta vtičnik ni bil samodejno nameščen.\nProsimo, namestite ga iz Tržnice.", + "task.errorMsg.unknown": "Ta vtičnik ni bil nameščen.\nIzvora vtičnika ni mogoče ugotoviti.", "task.errorPlugins": "Failed to Install Plugins", "task.installError": "{{errorLength}} vtičnikov ni uspelo namestiti, kliknite za ogled", + "task.installFromGithub": "Namestite iz GitHuba", + "task.installFromMarketplace": "Namestite iz Tržnice", "task.installSuccess": "{{successLength}} plugins installed successfully", "task.installed": "Installed", "task.installedError": "{{errorLength}} vtičnikov ni uspelo namestiti", "task.installing": "Nameščanje vtičnikov.", + "task.installingHint": 
"Nameščanje... To lahko traja nekaj minut.", "task.installingWithError": "Namestitev {{installingLength}} vtičnikov, {{successLength}} uspešnih, {{errorLength}} neuspešnih", "task.installingWithSuccess": "Namestitev {{installingLength}} dodatkov, {{successLength}} uspešnih.", "task.runningPlugins": "Installing Plugins", diff --git a/web/i18n/sl-SI/workflow.json b/web/i18n/sl-SI/workflow.json index f11175d1d4..a20a30753d 100644 --- a/web/i18n/sl-SI/workflow.json +++ b/web/i18n/sl-SI/workflow.json @@ -305,6 +305,7 @@ "error.operations.updatingWorkflow": "posodabljanje delovnega procesa", "error.startNodeRequired": "Prosimo, najprej dodajte začetni vozel pred {{operation}}", "errorMsg.authRequired": "Zahtevana je avtorizacija", + "errorMsg.configureModel": "Konfiguriraj model", "errorMsg.fieldRequired": "{{field}} je obvezno", "errorMsg.fields.code": "Koda", "errorMsg.fields.model": "Model", @@ -314,6 +315,7 @@ "errorMsg.fields.visionVariable": "Vizijska spremenljivka", "errorMsg.invalidJson": "{{field}} je neveljaven JSON", "errorMsg.invalidVariable": "Neveljavna spremenljivka", + "errorMsg.modelPluginNotInstalled": "Neveljavna spremenljivka. 
Konfigurirajte model za omogočanje te spremenljivke.", "errorMsg.noValidTool": "{{field}} ni izbranega veljavnega orodja", "errorMsg.rerankModelRequired": "Zahteva se konfigurirana model ponovnega razvrščanja.", "errorMsg.startNodeRequired": "Prosimo, najprej dodajte začetni vozel pred {{operation}}", @@ -444,6 +446,7 @@ "nodes.common.memory.windowSize": "Velikost okna", "nodes.common.outputVars": "Izhodne spremenljivke", "nodes.common.pluginNotInstalled": "Vtičnik ni nameščen", + "nodes.common.pluginsNotInstalled": "{{count}} vtičnikov ni nameščenih", "nodes.common.retry.maxRetries": "maksimalno število poskusov", "nodes.common.retry.ms": "ms", "nodes.common.retry.retries": "{{num}} Poskusi", @@ -686,9 +689,14 @@ "nodes.knowledgeBase.chunksInput": "Kosi", "nodes.knowledgeBase.chunksInputTip": "Vhodna spremenljivka vozlišča podatkovne baze je Chunks. Tip spremenljivke je objekt s specifično JSON shemo, ki mora biti skladna z izbrano strukturo kosov.", "nodes.knowledgeBase.chunksVariableIsRequired": "Spremenljivka Chunks je obvezna", + "nodes.knowledgeBase.embeddingModelApiKeyUnavailable": "API ključ ni na voljo", + "nodes.knowledgeBase.embeddingModelCreditsExhausted": "Krediti so porabljeni", + "nodes.knowledgeBase.embeddingModelIncompatible": "Nezdružljivo", "nodes.knowledgeBase.embeddingModelIsInvalid": "Vdelovalni model ni veljaven", "nodes.knowledgeBase.embeddingModelIsRequired": "Zahteva se vgrajevalni model", + "nodes.knowledgeBase.embeddingModelNotConfigured": "Vdelovalni model ni konfiguriran", "nodes.knowledgeBase.indexMethodIsRequired": "Zahteva se indeksna metoda", + "nodes.knowledgeBase.notConfigured": "Ni konfigurirano", "nodes.knowledgeBase.rerankingModelIsInvalid": "Model prerazvrščanja ni veljaven", "nodes.knowledgeBase.rerankingModelIsRequired": "Potreben je model za ponovno razvrščanje", "nodes.knowledgeBase.retrievalSettingIsRequired": "Zahtevana je nastavitev pridobivanja", @@ -872,6 +880,7 @@ "nodes.templateTransform.codeSupportTip": "Podpira 
samo Jinja2", "nodes.templateTransform.inputVars": "Vhodne spremenljivke", "nodes.templateTransform.outputVars.output": "Transformirana vsebina", + "nodes.tool.authorizationRequired": "Zahtevana je avtorizacija", "nodes.tool.authorize": "Pooblasti", "nodes.tool.inputVars": "Vhodne spremenljivke", "nodes.tool.insertPlaceholder1": "Vnesite ali pritisnite", @@ -1062,10 +1071,12 @@ "panel.change": "Spremeni", "panel.changeBlock": "Spremeni vozlišče", "panel.checklist": "Kontrolni seznam", + "panel.checklistDescription": "Rešite naslednje težave pred objavo", "panel.checklistResolved": "Vse težave so rešene", "panel.checklistTip": "Prepričajte se, da so vse težave rešene, preden objavite.", "panel.createdBy": "Ustvarjeno z", "panel.goTo": "Pojdi na", + "panel.goToFix": "Pojdi na popravek", "panel.helpLink": "Pomoč", "panel.maximize": "Maksimiziraj platno", "panel.minimize": "Izhod iz celotnega zaslona", diff --git a/web/i18n/th-TH/app-debug.json b/web/i18n/th-TH/app-debug.json index 53dc629a7f..c36db9ce18 100644 --- a/web/i18n/th-TH/app-debug.json +++ b/web/i18n/th-TH/app-debug.json @@ -235,11 +235,16 @@ "inputs.run": "วิ่ง", "inputs.title": "ดีบัก & ดูตัวอย่าง", "inputs.userInputField": "ฟิลด์ป้อนข้อมูลของผู้ใช้", + "manageModels": "จัดการโมเดล", "modelConfig.modeType.chat": "สนทนา", "modelConfig.modeType.completion": "สมบูรณ์", "modelConfig.model": "แบบ", "modelConfig.setTone": "กําหนดน้ําเสียงของการตอบกลับ", "modelConfig.title": "รุ่นและพารามิเตอร์", + "noModelProviderConfigured": "ยังไม่ได้กำหนดค่าผู้ให้บริการโมเดล", + "noModelProviderConfiguredTip": "ติดตั้งหรือกำหนดค่าผู้ให้บริการโมเดลเพื่อเริ่มต้นใช้งาน", + "noModelSelected": "ยังไม่ได้เลือกโมเดล", + "noModelSelectedTip": "กำหนดค่าโมเดลด้านบนเพื่อดำเนินการต่อ", "noResult": "ผลลัพธ์จะแสดงที่นี่", "notSetAPIKey.description": "ยังไม่ได้ตั้งค่าคีย์ผู้ให้บริการ LLM และจําเป็นต้องตั้งค่าก่อนการดีบัก", "notSetAPIKey.settingBtn": "ไปที่การตั้งค่า", diff --git a/web/i18n/th-TH/common.json b/web/i18n/th-TH/common.json index 
71c9369599..6eed5eba93 100644 --- a/web/i18n/th-TH/common.json +++ b/web/i18n/th-TH/common.json @@ -340,11 +340,25 @@ "modelProvider.auth.unAuthorized": "ไม่ได้รับอนุญาต", "modelProvider.buyQuota": "ซื้อโควต้า", "modelProvider.callTimes": "เวลาโทร", + "modelProvider.card.aiCreditsInUse": "กำลังใช้เครดิต AI", + "modelProvider.card.aiCreditsOption": "เครดิต AI", + "modelProvider.card.apiKeyOption": "API Key", + "modelProvider.card.apiKeyRequired": "ต้องการ API Key", + "modelProvider.card.apiKeyUnavailableFallback": "API Key ไม่พร้อมใช้งาน กำลังใช้เครดิต AI แทน", + "modelProvider.card.apiKeyUnavailableFallbackDescription": "ตรวจสอบการกำหนดค่า API Key ของคุณเพื่อสลับกลับ", "modelProvider.card.buyQuota": "ซื้อโควต้า", "modelProvider.card.callTimes": "เวลาโทร", + "modelProvider.card.creditsExhaustedDescription": "กรุณาอัปเกรดแพ็กเกจหรือกำหนดค่า API Key", + "modelProvider.card.creditsExhaustedFallback": "เครดิต AI หมดแล้ว กำลังใช้ API Key แทน", + "modelProvider.card.creditsExhaustedFallbackDescription": "อัปเกรดแพ็กเกจเพื่อกลับมาใช้เครดิต AI เป็นลำดับแรก", + "modelProvider.card.creditsExhaustedMessage": "เครดิต AI หมดแล้ว", "modelProvider.card.modelAPI": "โมเดล {{modelName}} กำลังใช้คีย์ API", "modelProvider.card.modelNotSupported": "โมเดล {{modelName}} ไม่ได้ติดตั้ง", "modelProvider.card.modelSupported": "โมเดล {{modelName}} กำลังใช้โควต้านี้", + "modelProvider.card.noApiKeysDescription": "เพิ่ม API Key เพื่อเริ่มใช้ข้อมูลรับรองโมเดลของคุณเอง", + "modelProvider.card.noApiKeysFallback": "ไม่มี API Key กำลังใช้เครดิต AI แทน", + "modelProvider.card.noApiKeysTitle": "ยังไม่ได้กำหนดค่า API Key", + "modelProvider.card.noAvailableUsage": "ไม่มีปริมาณการใช้งานที่พร้อมใช้", "modelProvider.card.onTrial": "ทดลองใช้", "modelProvider.card.paid": "จ่าย", "modelProvider.card.priorityUse": "ลําดับความสําคัญในการใช้งาน", @@ -353,6 +367,11 @@ "modelProvider.card.removeKey": "ลบคีย์ API", "modelProvider.card.tip": "เครดิตข้อความรองรับโมเดลจาก {{modelNames}} 
จะให้ลำดับความสำคัญกับโควต้าที่ชำระแล้ว โควต้าฟรีจะถูกใช้หลังจากโควต้าที่ชำระแล้วหมด", "modelProvider.card.tokens": "โท เค็น", + "modelProvider.card.unavailable": "ไม่พร้อมใช้งาน", + "modelProvider.card.upgradePlan": "อัปเกรดแพ็กเกจ", + "modelProvider.card.usageLabel": "การใช้งาน", + "modelProvider.card.usagePriority": "ลำดับความสำคัญการใช้งาน", + "modelProvider.card.usagePriorityTip": "ตั้งค่าทรัพยากรที่จะใช้ก่อนเมื่อเรียกใช้โมเดล", "modelProvider.collapse": "ทรุด", "modelProvider.config": "กําหนดค่า", "modelProvider.configLoadBalancing": "กําหนดค่าโหลดบาลานซ์", @@ -387,9 +406,11 @@ "modelProvider.model": "แบบ", "modelProvider.modelAndParameters": "รุ่นและพารามิเตอร์", "modelProvider.modelHasBeenDeprecated": "โมเดลนี้เลิกใช้แล้ว", + "modelProvider.modelSettings": "การตั้งค่าโมเดล", "modelProvider.models": "รุ่น", "modelProvider.modelsNum": "{{num}} รุ่น", "modelProvider.noModelFound": "ไม่พบแบบจําลองสําหรับ {{model}}", + "modelProvider.noneConfigured": "กำหนดค่าโมเดลระบบเริ่มต้นเพื่อเรียกใช้แอปพลิเคชัน", "modelProvider.notConfigured": "โมเดลระบบยังไม่ได้รับการกําหนดค่าอย่างสมบูรณ์ และฟังก์ชันบางอย่างอาจไม่พร้อมใช้งาน", "modelProvider.parameters": "พารามิเตอร์", "modelProvider.parametersInvalidRemoved": "บางพารามิเตอร์ไม่ถูกต้องและถูกนำออก", @@ -403,8 +424,25 @@ "modelProvider.resetDate": "รีเซ็ตเมื่อ {{date}}", "modelProvider.searchModel": "ค้นหารุ่น", "modelProvider.selectModel": "เลือกรุ่นของคุณ", + "modelProvider.selector.aiCredits": "เครดิต AI", + "modelProvider.selector.apiKeyUnavailable": "API Key ไม่พร้อมใช้งาน", + "modelProvider.selector.apiKeyUnavailableTip": "API Key ถูกลบแล้ว กรุณากำหนดค่า API Key ใหม่", + "modelProvider.selector.configure": "กำหนดค่า", + "modelProvider.selector.configureRequired": "ต้องกำหนดค่า", + "modelProvider.selector.creditsExhausted": "เครดิตหมดแล้ว", + "modelProvider.selector.creditsExhaustedTip": "เครดิต AI ของคุณหมดแล้ว กรุณาอัปเกรดแพ็กเกจหรือเพิ่ม API Key", + "modelProvider.selector.disabled": "ปิดใช้งาน", + 
"modelProvider.selector.discoverMoreInMarketplace": "ค้นพบเพิ่มเติมใน Marketplace", "modelProvider.selector.emptySetting": "โปรดไปที่การตั้งค่าเพื่อกําหนดค่า", "modelProvider.selector.emptyTip": "ไม่มีรุ่นที่พร้อมใช้งาน", + "modelProvider.selector.fromMarketplace": "จาก Marketplace", + "modelProvider.selector.incompatible": "ไม่รองรับ", + "modelProvider.selector.incompatibleTip": "โมเดลนี้ไม่พร้อมใช้งานในเวอร์ชันปัจจุบัน กรุณาเลือกโมเดลอื่นที่พร้อมใช้งาน", + "modelProvider.selector.install": "ติดตั้ง", + "modelProvider.selector.modelProviderSettings": "การตั้งค่าผู้ให้บริการโมเดล", + "modelProvider.selector.noProviderConfigured": "ยังไม่ได้กำหนดค่าผู้ให้บริการโมเดล", + "modelProvider.selector.noProviderConfiguredDesc": "เรียกดู Marketplace เพื่อติดตั้ง หรือกำหนดค่าผู้ให้บริการในการตั้งค่า", + "modelProvider.selector.onlyCompatibleModelsShown": "แสดงเฉพาะโมเดลที่รองรับเท่านั้น", "modelProvider.selector.rerankTip": "โปรดตั้งค่าโมเดล Rerank", "modelProvider.selector.tip": "รุ่นนี้ถูกลบออกแล้ว โปรดเพิ่มรุ่นหรือเลือกรุ่นอื่น", "modelProvider.setupModelFirst": "โปรดตั้งค่าโมเดลของคุณก่อน", diff --git a/web/i18n/th-TH/plugin.json b/web/i18n/th-TH/plugin.json index 1c5a544abc..0545d0d46b 100644 --- a/web/i18n/th-TH/plugin.json +++ b/web/i18n/th-TH/plugin.json @@ -3,6 +3,7 @@ "action.delete": "ลบปลั๊กอิน", "action.deleteContentLeft": "คุณต้องการลบ", "action.deleteContentRight": "ปลั๊กอิน?", + "action.deleteSuccess": "ลบปลั๊กอินสำเร็จแล้ว", "action.pluginInfo": "ข้อมูลปลั๊กอิน", "action.usedInApps": "ปลั๊กอินนี้ถูกใช้ในแอป {{num}}", "allCategories": "หมวดหมู่ทั้งหมด", @@ -114,6 +115,7 @@ "detailPanel.operation.install": "ติดตั้ง", "detailPanel.operation.remove": "ถอด", "detailPanel.operation.update": "อัพเดต", + "detailPanel.operation.updateTooltip": "อัปเดตเพื่อเข้าถึงโมเดลล่าสุด", "detailPanel.operation.viewDetail": "ดูรายละเอียด", "detailPanel.serviceOk": "บริการตกลง", "detailPanel.strategyNum": "{{num}} {{strategy}} รวม", @@ -231,12 +233,18 @@ "source.local": 
"ไฟล์แพ็คเกจในเครื่อง", "source.marketplace": "ตลาด", "task.clearAll": "ล้างทั้งหมด", + "task.errorMsg.github": "ไม่สามารถติดตั้งปลั๊กอินนี้โดยอัตโนมัติได้\nกรุณาติดตั้งจาก GitHub", + "task.errorMsg.marketplace": "ไม่สามารถติดตั้งปลั๊กอินนี้โดยอัตโนมัติได้\nกรุณาติดตั้งจาก Marketplace", + "task.errorMsg.unknown": "ไม่สามารถติดตั้งปลั๊กอินนี้ได้\nไม่สามารถระบุแหล่งที่มาของปลั๊กอินได้", "task.errorPlugins": "Failed to Install Plugins", "task.installError": "{{errorLength}} ปลั๊กอินติดตั้งไม่สําเร็จ คลิกเพื่อดู", + "task.installFromGithub": "ติดตั้งจาก GitHub", + "task.installFromMarketplace": "ติดตั้งจาก Marketplace", "task.installSuccess": "{{successLength}} plugins installed successfully", "task.installed": "Installed", "task.installedError": "{{errorLength}} ปลั๊กอินติดตั้งไม่สําเร็จ", "task.installing": "การติดตั้งปลั๊กอิน", + "task.installingHint": "กำลังติดตั้ง... อาจใช้เวลาสักครู่", "task.installingWithError": "การติดตั้งปลั๊กอิน {{installingLength}}, {{successLength}} สําเร็จ, {{errorLength}} ล้มเหลว", "task.installingWithSuccess": "การติดตั้งปลั๊กอิน {{installingLength}}, {{successLength}} สําเร็จ", "task.runningPlugins": "Installing Plugins", diff --git a/web/i18n/th-TH/workflow.json b/web/i18n/th-TH/workflow.json index 84bfe5ed97..7819e884c3 100644 --- a/web/i18n/th-TH/workflow.json +++ b/web/i18n/th-TH/workflow.json @@ -305,6 +305,7 @@ "error.operations.updatingWorkflow": "กำลังปรับปรุงเวิร์กโฟลว์", "error.startNodeRequired": "โปรดเพิ่มโหนดเริ่มต้นก่อน {{operation}}", "errorMsg.authRequired": "ต้องได้รับอนุญาต", + "errorMsg.configureModel": "กำหนดค่าโมเดล", "errorMsg.fieldRequired": "{{field}} เป็นสิ่งจําเป็น", "errorMsg.fields.code": "รหัส", "errorMsg.fields.model": "แบบ", @@ -314,6 +315,7 @@ "errorMsg.fields.visionVariable": "ตัวแปรวิสัยทัศน์", "errorMsg.invalidJson": "{{field}} เป็น JSON ไม่ถูกต้อง", "errorMsg.invalidVariable": "ตัวแปรไม่ถูกต้อง", + "errorMsg.modelPluginNotInstalled": "ตัวแปรไม่ถูกต้อง กำหนดค่าโมเดลเพื่อเปิดใช้งานตัวแปรนี้", 
"errorMsg.noValidTool": "{{field}} ไม่ได้เลือกเครื่องมือที่ถูกต้อง", "errorMsg.rerankModelRequired": "ก่อนเปิด Rerank Model โปรดยืนยันว่าได้กําหนดค่าโมเดลสําเร็จในการตั้งค่า", "errorMsg.startNodeRequired": "โปรดเพิ่มโหนดเริ่มต้นก่อน {{operation}}", @@ -444,6 +446,7 @@ "nodes.common.memory.windowSize": "ขนาดหน้าต่าง", "nodes.common.outputVars": "ตัวแปรเอาต์พุต", "nodes.common.pluginNotInstalled": "ปลั๊กอินไม่ได้ติดตั้ง", + "nodes.common.pluginsNotInstalled": "{{count}} ปลั๊กอินยังไม่ได้ติดตั้ง", "nodes.common.retry.maxRetries": "การลองซ้ําสูงสุด", "nodes.common.retry.ms": "นางสาว", "nodes.common.retry.retries": "{{num}} ลอง", @@ -686,9 +689,14 @@ "nodes.knowledgeBase.chunksInput": "ชิ้นส่วน", "nodes.knowledgeBase.chunksInputTip": "ตัวแปรนำเข้าของโหนดฐานความรู้คือ Chunks ตัวแปรประเภทเป็นอ็อบเจ็กต์ที่มี JSON Schema เฉพาะซึ่งต้องสอดคล้องกับโครงสร้างชิ้นส่วนที่เลือกไว้.", "nodes.knowledgeBase.chunksVariableIsRequired": "ตัวแปร Chunks เป็นสิ่งจำเป็น", + "nodes.knowledgeBase.embeddingModelApiKeyUnavailable": "API Key ไม่พร้อมใช้งาน", + "nodes.knowledgeBase.embeddingModelCreditsExhausted": "เครดิตหมดแล้ว", + "nodes.knowledgeBase.embeddingModelIncompatible": "ไม่รองรับ", "nodes.knowledgeBase.embeddingModelIsInvalid": "แบบจำลองการฝังไม่ถูกต้อง", "nodes.knowledgeBase.embeddingModelIsRequired": "จำเป็นต้องใช้โมเดลฝัง", + "nodes.knowledgeBase.embeddingModelNotConfigured": "ยังไม่ได้กำหนดค่าโมเดล Embedding", "nodes.knowledgeBase.indexMethodIsRequired": "ต้องใช้วิธีการจัดทําดัชนี", + "nodes.knowledgeBase.notConfigured": "ยังไม่ได้กำหนดค่า", "nodes.knowledgeBase.rerankingModelIsInvalid": "โมเดลการจัดอันดับใหม่ไม่ถูกต้อง", "nodes.knowledgeBase.rerankingModelIsRequired": "จำเป็นต้องมีโมเดลการจัดอันดับใหม่", "nodes.knowledgeBase.retrievalSettingIsRequired": "จําเป็นต้องมีการตั้งค่าการดึงข้อมูล", @@ -872,6 +880,7 @@ "nodes.templateTransform.codeSupportTip": "รองรับเฉพาะ Jinja2", "nodes.templateTransform.inputVars": "ตัวแปรอินพุต", "nodes.templateTransform.outputVars.output": 
"เนื้อหาที่แปลงโฉม", + "nodes.tool.authorizationRequired": "ต้องการการอนุญาต", "nodes.tool.authorize": "อนุญาต", "nodes.tool.inputVars": "ตัวแปรอินพุต", "nodes.tool.insertPlaceholder1": "พิมพ์หรือลงทะเบียน", @@ -1062,10 +1071,12 @@ "panel.change": "เปลี่ยน", "panel.changeBlock": "เปลี่ยนโหนด", "panel.checklist": "ตรวจ สอบ", + "panel.checklistDescription": "กรุณาแก้ไขปัญหาต่อไปนี้ก่อนเผยแพร่", "panel.checklistResolved": "ปัญหาทั้งหมดได้รับการแก้ไขแล้ว", "panel.checklistTip": "ตรวจสอบให้แน่ใจว่าปัญหาทั้งหมดได้รับการแก้ไขแล้วก่อนที่จะเผยแพร่", "panel.createdBy": "สร้างโดย", "panel.goTo": "ไปที่", + "panel.goToFix": "ไปแก้ไข", "panel.helpLink": "วิธีใช้", "panel.maximize": "เพิ่มประสิทธิภาพผ้าใบ", "panel.minimize": "ออกจากโหมดเต็มหน้าจอ", diff --git a/web/i18n/uk-UA/app-debug.json b/web/i18n/uk-UA/app-debug.json index 74dcc72efd..dd7ec2de8f 100644 --- a/web/i18n/uk-UA/app-debug.json +++ b/web/i18n/uk-UA/app-debug.json @@ -235,11 +235,16 @@ "inputs.run": "ЗАПУСТИТИ", "inputs.title": "Налагодження та попередній перегляд", "inputs.userInputField": "Поле введення користувача", + "manageModels": "Керування моделями", "modelConfig.modeType.chat": "Чат", "modelConfig.modeType.completion": "Завершення", "modelConfig.model": "Модель", "modelConfig.setTone": "Встановити тон відповідей", "modelConfig.title": "Модель і параметри", + "noModelProviderConfigured": "Постачальник моделей не налаштований", + "noModelProviderConfiguredTip": "Встановіть або налаштуйте постачальника моделей, щоб почати роботу.", + "noModelSelected": "Модель не вибрана", + "noModelSelectedTip": "Налаштуйте модель вище, щоб продовжити.", "noResult": "Тут буде відображено вихідні дані.", "notSetAPIKey.description": "Ключ провайдера LLM не встановлено, і його потрібно встановити перед налагодженням.", "notSetAPIKey.settingBtn": "Перейти до налаштувань", diff --git a/web/i18n/uk-UA/common.json b/web/i18n/uk-UA/common.json index 6745644573..806cbede3d 100644 --- a/web/i18n/uk-UA/common.json +++ 
b/web/i18n/uk-UA/common.json @@ -340,11 +340,25 @@ "modelProvider.auth.unAuthorized": "Несанкціоновано", "modelProvider.buyQuota": "Придбати квоту", "modelProvider.callTimes": "Кількість викликів", + "modelProvider.card.aiCreditsInUse": "Використовуються AI-кредити", + "modelProvider.card.aiCreditsOption": "AI-кредити", + "modelProvider.card.apiKeyOption": "API Key", + "modelProvider.card.apiKeyRequired": "Потрібен API-ключ", + "modelProvider.card.apiKeyUnavailableFallback": "API Key недоступний, використовуються AI-кредити", + "modelProvider.card.apiKeyUnavailableFallbackDescription": "Перевірте конфігурацію API-ключа, щоб повернутися до нього", "modelProvider.card.buyQuota": "Придбати квоту", "modelProvider.card.callTimes": "Кількість викликів", + "modelProvider.card.creditsExhaustedDescription": "Будь ласка, оновіть свій план або налаштуйте API-ключ", + "modelProvider.card.creditsExhaustedFallback": "AI-кредити вичерпано, використовується API Key", + "modelProvider.card.creditsExhaustedFallbackDescription": "Оновіть свій план, щоб відновити пріоритет AI-кредитів.", + "modelProvider.card.creditsExhaustedMessage": "AI-кредити вичерпано", "modelProvider.card.modelAPI": "Моделі {{modelName}} використовують API-ключ.", "modelProvider.card.modelNotSupported": "Моделі {{modelName}} не встановлено.", "modelProvider.card.modelSupported": "Моделі {{modelName}} використовують цю квоту.", + "modelProvider.card.noApiKeysDescription": "Додайте API-ключ, щоб почати використовувати власні облікові дані моделі.", + "modelProvider.card.noApiKeysFallback": "API-ключі відсутні, використовуються AI-кредити", + "modelProvider.card.noApiKeysTitle": "API-ключі ще не налаштовані", + "modelProvider.card.noAvailableUsage": "Немає доступного використання", "modelProvider.card.onTrial": "У пробному періоді", "modelProvider.card.paid": "Оплачено", "modelProvider.card.priorityUse": "Пріоритетне використання", @@ -353,6 +367,11 @@ "modelProvider.card.removeKey": "Видалити ключ API", 
"modelProvider.card.tip": "Кредити повідомлень підтримують моделі від {{modelNames}}. Пріоритет буде надано оплаченій квоті. Безкоштовна квота буде використовуватися після вичерпання платної квоти.", "modelProvider.card.tokens": "Токени", + "modelProvider.card.unavailable": "Недоступно", + "modelProvider.card.upgradePlan": "оновіть свій план", + "modelProvider.card.usageLabel": "Використання", + "modelProvider.card.usagePriority": "Пріоритет використання", + "modelProvider.card.usagePriorityTip": "Встановіть, який ресурс використовувати першим при запуску моделей.", "modelProvider.collapse": "Згорнути", "modelProvider.config": "Налаштування", "modelProvider.configLoadBalancing": "Балансування навантаження конфігурації", @@ -387,9 +406,11 @@ "modelProvider.model": "Модель", "modelProvider.modelAndParameters": "Модель та параметри", "modelProvider.modelHasBeenDeprecated": "Ця модель вважається застарілою", + "modelProvider.modelSettings": "Налаштування моделі", "modelProvider.models": "Моделі", "modelProvider.modelsNum": "{{num}} моделей", "modelProvider.noModelFound": "Модель для {{model}} не знайдено", + "modelProvider.noneConfigured": "Налаштуйте системну модель за замовчуванням для запуску застосунків", "modelProvider.notConfigured": "Системну модель ще не повністю налаштовано, і деякі функції можуть бути недоступні.", "modelProvider.parameters": "ПАРАМЕТРИ", "modelProvider.parametersInvalidRemoved": "Деякі параметри є недійсними і були видалені", @@ -403,8 +424,25 @@ "modelProvider.resetDate": "Скидання {{date}}", "modelProvider.searchModel": "Пошукова модель", "modelProvider.selectModel": "Виберіть свою модель", + "modelProvider.selector.aiCredits": "AI-кредити", + "modelProvider.selector.apiKeyUnavailable": "API Key недоступний", + "modelProvider.selector.apiKeyUnavailableTip": "API-ключ було видалено. 
Будь ласка, налаштуйте новий API-ключ.", + "modelProvider.selector.configure": "Налаштувати", + "modelProvider.selector.configureRequired": "Потрібне налаштування", + "modelProvider.selector.creditsExhausted": "Кредити вичерпано", + "modelProvider.selector.creditsExhaustedTip": "Ваші AI-кредити вичерпано. Будь ласка, оновіть свій план або додайте API-ключ.", + "modelProvider.selector.disabled": "Вимкнено", + "modelProvider.selector.discoverMoreInMarketplace": "Знайти більше в Marketplace", "modelProvider.selector.emptySetting": "Перейдіть до налаштувань, щоб налаштувати", "modelProvider.selector.emptyTip": "Доступні моделі відсутні", + "modelProvider.selector.fromMarketplace": "З Marketplace", + "modelProvider.selector.incompatible": "Несумісно", + "modelProvider.selector.incompatibleTip": "Ця модель недоступна в поточній версії. Будь ласка, виберіть іншу доступну модель.", + "modelProvider.selector.install": "Встановити", + "modelProvider.selector.modelProviderSettings": "Налаштування постачальника моделей", + "modelProvider.selector.noProviderConfigured": "Постачальник моделей не налаштований", + "modelProvider.selector.noProviderConfiguredDesc": "Перегляньте Marketplace для встановлення або налаштуйте постачальників у параметрах.", + "modelProvider.selector.onlyCompatibleModelsShown": "Показано лише сумісні моделі", "modelProvider.selector.rerankTip": "Будь ласка, налаштуйте модель повторного ранжування", "modelProvider.selector.tip": "Цю модель було видалено. 
Будь ласка, додайте модель або виберіть іншу.", "modelProvider.setupModelFirst": "Будь ласка, спочатку налаштуйте свою модель", diff --git a/web/i18n/uk-UA/plugin.json b/web/i18n/uk-UA/plugin.json index 98bebf0ca8..603a8226a3 100644 --- a/web/i18n/uk-UA/plugin.json +++ b/web/i18n/uk-UA/plugin.json @@ -3,6 +3,7 @@ "action.delete": "Видалити плагін", "action.deleteContentLeft": "Чи хотіли б ви видалити", "action.deleteContentRight": "плагін?", + "action.deleteSuccess": "Плагін успішно видалено", "action.pluginInfo": "Інформація про плагін", "action.usedInApps": "Цей плагін використовується в додатках {{num}}.", "allCategories": "Всі категорії", @@ -114,6 +115,7 @@ "detailPanel.operation.install": "Інсталювати", "detailPanel.operation.remove": "Видалити", "detailPanel.operation.update": "Оновлювати", + "detailPanel.operation.updateTooltip": "Оновіть, щоб отримати доступ до найновіших моделей.", "detailPanel.operation.viewDetail": "Переглянути деталі", "detailPanel.serviceOk": "Сервіс працює", "detailPanel.strategyNum": "{{num}} {{strategy}} ВКЛЮЧЕНІ", @@ -231,12 +233,18 @@ "source.local": "Файл локального пакета", "source.marketplace": "Ринку", "task.clearAll": "Очистити все", + "task.errorMsg.github": "Цей плагін не вдалося встановити автоматично.\nБудь ласка, встановіть його з GitHub.", + "task.errorMsg.marketplace": "Цей плагін не вдалося встановити автоматично.\nБудь ласка, встановіть його з Marketplace.", + "task.errorMsg.unknown": "Не вдалося встановити цей плагін.\nДжерело плагіна не вдалося визначити.", "task.errorPlugins": "Failed to Install Plugins", "task.installError": "Плагіни {{errorLength}} не вдалося встановити, натисніть, щоб переглянути", + "task.installFromGithub": "Встановити з GitHub", + "task.installFromMarketplace": "Встановити з Marketplace", "task.installSuccess": "{{successLength}} plugins installed successfully", "task.installed": "Installed", "task.installedError": "Плагіни {{errorLength}} не вдалося встановити", "task.installing": 
"Встановлення плагінів.", + "task.installingHint": "Встановлення... Це може зайняти кілька хвилин.", "task.installingWithError": "Не вдалося встановити плагіни {{installingLength}}, успіх {{successLength}}, {{errorLength}}", "task.installingWithSuccess": "Встановлення плагінів {{installingLength}}, успіх {{successLength}}.", "task.runningPlugins": "Installing Plugins", diff --git a/web/i18n/uk-UA/workflow.json b/web/i18n/uk-UA/workflow.json index e9365303f2..4fa95f6d57 100644 --- a/web/i18n/uk-UA/workflow.json +++ b/web/i18n/uk-UA/workflow.json @@ -305,6 +305,7 @@ "error.operations.updatingWorkflow": "оновлення робочого процесу", "error.startNodeRequired": "Будь ласка, спершу додайте стартовий вузол перед {{operation}}", "errorMsg.authRequired": "Потрібна авторизація", + "errorMsg.configureModel": "Налаштуйте модель", "errorMsg.fieldRequired": "{{field}} є обов'язковим", "errorMsg.fields.code": "Код", "errorMsg.fields.model": "Модель", @@ -314,6 +315,7 @@ "errorMsg.fields.visionVariable": "Змінна зору", "errorMsg.invalidJson": "{{field}} є недійсним JSON", "errorMsg.invalidVariable": "Недійсна змінна", + "errorMsg.modelPluginNotInstalled": "Недійсна змінна. 
Налаштуйте модель, щоб увімкнути цю змінну.", "errorMsg.noValidTool": "{{field}} не вибрано дійсного інструменту", "errorMsg.rerankModelRequired": "Перед увімкненням Rerank Model, будь ласка, підтвердьте, що модель успішно налаштована в налаштуваннях.", "errorMsg.startNodeRequired": "Будь ласка, спершу додайте стартовий вузол перед {{operation}}", @@ -444,6 +446,7 @@ "nodes.common.memory.windowSize": "Розмір вікна", "nodes.common.outputVars": "Змінні виходу", "nodes.common.pluginNotInstalled": "Плагін не встановлений", + "nodes.common.pluginsNotInstalled": "{{count}} плагінів не встановлено", "nodes.common.retry.maxRetries": "Максимальна кількість повторних спроб", "nodes.common.retry.ms": "МС", "nodes.common.retry.retries": "{{num}} Спроб", @@ -686,9 +689,14 @@ "nodes.knowledgeBase.chunksInput": "Частини", "nodes.knowledgeBase.chunksInputTip": "Вхідна змінна вузла бази знань - це Частини. Тип змінної - об'єкт з певною JSON-схемою, яка повинна відповідати вибраній структурі частин.", "nodes.knowledgeBase.chunksVariableIsRequired": "Змінна chunks є обов'язковою", + "nodes.knowledgeBase.embeddingModelApiKeyUnavailable": "API-ключ недоступний", + "nodes.knowledgeBase.embeddingModelCreditsExhausted": "Кредити вичерпано", + "nodes.knowledgeBase.embeddingModelIncompatible": "Несумісно", "nodes.knowledgeBase.embeddingModelIsInvalid": "Модель вбудовування недійсна", "nodes.knowledgeBase.embeddingModelIsRequired": "Потрібна модель вбудовування", + "nodes.knowledgeBase.embeddingModelNotConfigured": "Модель ембедингу не налаштована", "nodes.knowledgeBase.indexMethodIsRequired": "Обов'язковий індексний метод", + "nodes.knowledgeBase.notConfigured": "Не налаштовано", "nodes.knowledgeBase.rerankingModelIsInvalid": "Модель переналаштування недійсна", "nodes.knowledgeBase.rerankingModelIsRequired": "Потрібна модель повторного ранжування", "nodes.knowledgeBase.retrievalSettingIsRequired": "Потрібне налаштування для отримання", @@ -872,6 +880,7 @@ 
"nodes.templateTransform.codeSupportTip": "Підтримує лише Jinja2", "nodes.templateTransform.inputVars": "Вхідні змінні", "nodes.templateTransform.outputVars.output": "Трансформований вміст", + "nodes.tool.authorizationRequired": "Потрібна авторизація", "nodes.tool.authorize": "Уповноважити", "nodes.tool.inputVars": "Вхідні змінні", "nodes.tool.insertPlaceholder1": "Введіть або натисніть", @@ -1062,10 +1071,12 @@ "panel.change": "Змінити", "panel.changeBlock": "Змінити вузол", "panel.checklist": "Контрольний список", + "panel.checklistDescription": "Вирішіть наступні проблеми перед публікацією", "panel.checklistResolved": "Всі проблеми вирішені", "panel.checklistTip": "Переконайтеся, що всі проблеми вирішені перед публікацією", "panel.createdBy": "Створено ", "panel.goTo": "Перейти до", + "panel.goToFix": "Перейти до виправлення", "panel.helpLink": "Довідковий центр", "panel.maximize": "Максимізувати полотно", "panel.minimize": "Вийти з повноекранного режиму", diff --git a/web/i18n/vi-VN/app-debug.json b/web/i18n/vi-VN/app-debug.json index d533a02370..92037650cb 100644 --- a/web/i18n/vi-VN/app-debug.json +++ b/web/i18n/vi-VN/app-debug.json @@ -235,11 +235,16 @@ "inputs.run": "CHẠY", "inputs.title": "Gỡ lỗi và xem trước", "inputs.userInputField": "Trường nhập liệu người dùng", + "manageModels": "Quản lý mô hình", "modelConfig.modeType.chat": "Trò chuyện", "modelConfig.modeType.completion": "Hoàn thành", "modelConfig.model": "Mô hình", "modelConfig.setTone": "Thiết lập giọng điệu của phản hồi", "modelConfig.title": "Mô hình và tham số", + "noModelProviderConfigured": "Chưa cấu hình nhà cung cấp mô hình", + "noModelProviderConfiguredTip": "Cài đặt hoặc cấu hình nhà cung cấp mô hình để bắt đầu.", + "noModelSelected": "Chưa chọn mô hình", + "noModelSelectedTip": "cấu hình mô hình ở trên để tiếp tục.", "noResult": "Đầu ra sẽ được hiển thị ở đây.", "notSetAPIKey.description": "Chưa thiết lập khóa API của nhà cung cấp LLM. 
Cần thiết lập trước khi gỡ lỗi.", "notSetAPIKey.settingBtn": "Đi đến cài đặt", diff --git a/web/i18n/vi-VN/common.json b/web/i18n/vi-VN/common.json index a47ee6da57..820bfdfdab 100644 --- a/web/i18n/vi-VN/common.json +++ b/web/i18n/vi-VN/common.json @@ -340,11 +340,25 @@ "modelProvider.auth.unAuthorized": "Không có quyền truy cập", "modelProvider.buyQuota": "Mua Quyền lợi", "modelProvider.callTimes": "Số lần gọi", + "modelProvider.card.aiCreditsInUse": "Đang sử dụng AI credits", + "modelProvider.card.aiCreditsOption": "AI credits", + "modelProvider.card.apiKeyOption": "API Key", + "modelProvider.card.apiKeyRequired": "Yêu cầu API key", + "modelProvider.card.apiKeyUnavailableFallback": "API Key không khả dụng, đang sử dụng AI credits", + "modelProvider.card.apiKeyUnavailableFallbackDescription": "Kiểm tra cấu hình API key để chuyển lại", "modelProvider.card.buyQuota": "Mua Quota", "modelProvider.card.callTimes": "Số lần gọi", + "modelProvider.card.creditsExhaustedDescription": "Vui lòng nâng cấp gói dịch vụ hoặc cấu hình API key", + "modelProvider.card.creditsExhaustedFallback": "AI credits đã hết, đang sử dụng API Key", + "modelProvider.card.creditsExhaustedFallbackDescription": "Nâng cấp gói dịch vụ để khôi phục ưu tiên AI credits.", + "modelProvider.card.creditsExhaustedMessage": "AI credits đã hết", "modelProvider.card.modelAPI": "Các mô hình {{modelName}} đang sử dụng Khóa API.", "modelProvider.card.modelNotSupported": "Các mô hình {{modelName}} chưa được cài đặt.", "modelProvider.card.modelSupported": "Các mô hình {{modelName}} đang sử dụng hạn mức này.", + "modelProvider.card.noApiKeysDescription": "Thêm API key để bắt đầu sử dụng thông tin xác thực mô hình của bạn.", + "modelProvider.card.noApiKeysFallback": "Không có API key, sử dụng AI credits thay thế", + "modelProvider.card.noApiKeysTitle": "Chưa cấu hình API key", + "modelProvider.card.noAvailableUsage": "Không có lượt sử dụng khả dụng", "modelProvider.card.onTrial": "Thử nghiệm", 
"modelProvider.card.paid": "Đã thanh toán", "modelProvider.card.priorityUse": "Ưu tiên sử dụng", @@ -353,6 +367,11 @@ "modelProvider.card.removeKey": "Remove API Key", "modelProvider.card.tip": "Tín dụng tin nhắn hỗ trợ các mô hình từ {{modelNames}}. Ưu tiên sẽ được trao cho hạn ngạch đã thanh toán. Hạn ngạch miễn phí sẽ được sử dụng sau khi hết hạn ngạch trả phí.", "modelProvider.card.tokens": "Tokens", + "modelProvider.card.unavailable": "Không khả dụng", + "modelProvider.card.upgradePlan": "nâng cấp gói dịch vụ", + "modelProvider.card.usageLabel": "Sử dụng", + "modelProvider.card.usagePriority": "Ưu tiên sử dụng", + "modelProvider.card.usagePriorityTip": "Đặt tài nguyên sử dụng trước khi chạy mô hình.", "modelProvider.collapse": "Thu gọn", "modelProvider.config": "Cấu hình", "modelProvider.configLoadBalancing": "Cấu hình cân bằng tải", @@ -387,9 +406,11 @@ "modelProvider.model": "Mô hình", "modelProvider.modelAndParameters": "Mô hình và Tham số", "modelProvider.modelHasBeenDeprecated": "Mô hình này đã bị phản đối", + "modelProvider.modelSettings": "Cài đặt mô hình", "modelProvider.models": "Mô hình", "modelProvider.modelsNum": "{{num}} Mô hình", "modelProvider.noModelFound": "Không tìm thấy mô hình cho {{model}}", + "modelProvider.noneConfigured": "Cấu hình mô hình hệ thống mặc định để chạy ứng dụng", "modelProvider.notConfigured": "Mô hình hệ thống vẫn chưa được cấu hình hoàn toàn và một số chức năng có thể không khả dụng.", "modelProvider.parameters": "THAM SỐ", "modelProvider.parametersInvalidRemoved": "Một số tham số không hợp lệ và đã được loại bỏ", @@ -403,8 +424,25 @@ "modelProvider.resetDate": "Đặt lại vào {{date}}", "modelProvider.searchModel": "Mô hình tìm kiếm", "modelProvider.selectModel": "Chọn mô hình của bạn", + "modelProvider.selector.aiCredits": "AI credits", + "modelProvider.selector.apiKeyUnavailable": "API Key không khả dụng", + "modelProvider.selector.apiKeyUnavailableTip": "API key đã bị xóa. 
Vui lòng cấu hình API key mới.", + "modelProvider.selector.configure": "Cấu hình", + "modelProvider.selector.configureRequired": "Cần cấu hình", + "modelProvider.selector.creditsExhausted": "Credits đã hết", + "modelProvider.selector.creditsExhaustedTip": "AI credits của bạn đã hết. Vui lòng nâng cấp gói dịch vụ hoặc thêm API key.", + "modelProvider.selector.disabled": "Đã tắt", + "modelProvider.selector.discoverMoreInMarketplace": "Khám phá thêm trên Marketplace", "modelProvider.selector.emptySetting": "Vui lòng vào cài đặt để cấu hình", "modelProvider.selector.emptyTip": "Không có mô hình khả dụng", + "modelProvider.selector.fromMarketplace": "Từ Marketplace", + "modelProvider.selector.incompatible": "Không tương thích", + "modelProvider.selector.incompatibleTip": "Mô hình này không khả dụng trong phiên bản hiện tại. Vui lòng chọn mô hình khả dụng khác.", + "modelProvider.selector.install": "Cài đặt", + "modelProvider.selector.modelProviderSettings": "Cài đặt nhà cung cấp mô hình", + "modelProvider.selector.noProviderConfigured": "Chưa cấu hình nhà cung cấp mô hình", + "modelProvider.selector.noProviderConfiguredDesc": "Duyệt Marketplace để cài đặt hoặc cấu hình nhà cung cấp trong phần cài đặt.", + "modelProvider.selector.onlyCompatibleModelsShown": "Chỉ hiển thị các mô hình tương thích", "modelProvider.selector.rerankTip": "Vui lòng thiết lập mô hình sắp xếp lại", "modelProvider.selector.tip": "Mô hình này đã bị xóa. 
Vui lòng thêm một mô hình hoặc chọn mô hình khác.", "modelProvider.setupModelFirst": "Vui lòng thiết lập mô hình của bạn trước", diff --git a/web/i18n/vi-VN/plugin.json b/web/i18n/vi-VN/plugin.json index 40cc20b7b0..1c1b37bba3 100644 --- a/web/i18n/vi-VN/plugin.json +++ b/web/i18n/vi-VN/plugin.json @@ -3,6 +3,7 @@ "action.delete": "Xóa plugin", "action.deleteContentLeft": "Bạn có muốn xóa", "action.deleteContentRight": "plugin?", + "action.deleteSuccess": "Đã xóa plugin thành công", "action.pluginInfo": "Thông tin plugin", "action.usedInApps": "Plugin này đang được sử dụng trong các ứng dụng {{num}}.", "allCategories": "Tất cả các danh mục", @@ -114,6 +115,7 @@ "detailPanel.operation.install": "Cài đặt", "detailPanel.operation.remove": "Triệt", "detailPanel.operation.update": "Cập nhật", + "detailPanel.operation.updateTooltip": "Cập nhật để truy cập các mô hình mới nhất.", "detailPanel.operation.viewDetail": "xem chi tiết", "detailPanel.serviceOk": "Dịch vụ OK", "detailPanel.strategyNum": "{{num}} {{strategy}} BAO GỒM", @@ -231,12 +233,18 @@ "source.local": "Tệp gói cục bộ", "source.marketplace": "Chợ", "task.clearAll": "Xóa tất cả", + "task.errorMsg.github": "Không thể cài đặt plugin này tự động.\nVui lòng cài đặt từ GitHub.", + "task.errorMsg.marketplace": "Không thể cài đặt plugin này tự động.\nVui lòng cài đặt từ Marketplace.", + "task.errorMsg.unknown": "Không thể cài đặt plugin này.\nKhông xác định được nguồn plugin.", "task.errorPlugins": "Failed to Install Plugins", "task.installError": "{{errorLength}} plugin không cài đặt được, nhấp để xem", + "task.installFromGithub": "Cài đặt từ GitHub", + "task.installFromMarketplace": "Cài đặt từ Marketplace", "task.installSuccess": "{{successLength}} plugins installed successfully", "task.installed": "Installed", "task.installedError": "{{errorLength}} plugin không cài đặt được", "task.installing": "Đang cài đặt plugin.", + "task.installingHint": "Đang cài đặt... 
Quá trình này có thể mất vài phút.", "task.installingWithError": "Cài đặt {{installingLength}} plugins, {{successLength}} thành công, {{errorLength}} không thành công", "task.installingWithSuccess": "Cài đặt {{installingLength}} plugins, {{successLength}} thành công.", "task.runningPlugins": "Installing Plugins", diff --git a/web/i18n/vi-VN/workflow.json b/web/i18n/vi-VN/workflow.json index 3bdd86e231..3a8bdbaaf1 100644 --- a/web/i18n/vi-VN/workflow.json +++ b/web/i18n/vi-VN/workflow.json @@ -305,6 +305,7 @@ "error.operations.updatingWorkflow": "cập nhật quy trình công việc", "error.startNodeRequired": "Vui lòng thêm một nút bắt đầu trước {{operation}}", "errorMsg.authRequired": "Yêu cầu xác thực", + "errorMsg.configureModel": "Cấu hình mô hình", "errorMsg.fieldRequired": "{{field}} là bắt buộc", "errorMsg.fields.code": "Mã", "errorMsg.fields.model": "Mô hình", @@ -314,6 +315,7 @@ "errorMsg.fields.visionVariable": "Biến tầm nhìn", "errorMsg.invalidJson": "{{field}} là JSON không hợp lệ", "errorMsg.invalidVariable": "Biến không hợp lệ", + "errorMsg.modelPluginNotInstalled": "Biến không hợp lệ. Cấu hình mô hình để kích hoạt biến này.", "errorMsg.noValidTool": "{{field}} không chọn công cụ hợp lệ nào", "errorMsg.rerankModelRequired": "Trước khi bật Mô hình xếp hạng lại, vui lòng xác nhận rằng mô hình đã được định cấu hình thành công trong cài đặt.", "errorMsg.startNodeRequired": "Vui lòng thêm một nút bắt đầu trước {{operation}}", @@ -444,6 +446,7 @@ "nodes.common.memory.windowSize": "Kích thước cửa sổ", "nodes.common.outputVars": "Biến đầu ra", "nodes.common.pluginNotInstalled": "Plugin chưa được cài đặt", + "nodes.common.pluginsNotInstalled": "{{count}} plugin chưa được cài đặt", "nodes.common.retry.maxRetries": "Số lần thử lại tối đa", "nodes.common.retry.ms": "Ms", "nodes.common.retry.retries": "{{num}} Thử lại", @@ -686,9 +689,14 @@ "nodes.knowledgeBase.chunksInput": "Mảnh", "nodes.knowledgeBase.chunksInputTip": "Biến đầu vào của nút cơ sở tri thức là Chunks. 
Loại biến là một đối tượng với một JSON Schema cụ thể mà phải nhất quán với cấu trúc chunk đã chọn.", "nodes.knowledgeBase.chunksVariableIsRequired": "Biến Chunks là bắt buộc", + "nodes.knowledgeBase.embeddingModelApiKeyUnavailable": "API key không khả dụng", + "nodes.knowledgeBase.embeddingModelCreditsExhausted": "Credits đã hết", + "nodes.knowledgeBase.embeddingModelIncompatible": "Không tương thích", "nodes.knowledgeBase.embeddingModelIsInvalid": "Mô hình nhúng không hợp lệ", "nodes.knowledgeBase.embeddingModelIsRequired": "Cần có mô hình nhúng", + "nodes.knowledgeBase.embeddingModelNotConfigured": "Chưa cấu hình mô hình nhúng", "nodes.knowledgeBase.indexMethodIsRequired": "Phương pháp chỉ mục là bắt buộc", + "nodes.knowledgeBase.notConfigured": "Chưa cấu hình", "nodes.knowledgeBase.rerankingModelIsInvalid": "Mô hình xếp hạng lại không hợp lệ", "nodes.knowledgeBase.rerankingModelIsRequired": "Cần có mô hình sắp xếp lại", "nodes.knowledgeBase.retrievalSettingIsRequired": "Cài đặt truy xuất là bắt buộc", @@ -872,6 +880,7 @@ "nodes.templateTransform.codeSupportTip": "Chỉ hỗ trợ Jinja2", "nodes.templateTransform.inputVars": "Biến đầu vào", "nodes.templateTransform.outputVars.output": "Nội dung chuyển đổi", + "nodes.tool.authorizationRequired": "Yêu cầu ủy quyền", "nodes.tool.authorize": "Ủy quyền", "nodes.tool.inputVars": "Biến đầu vào", "nodes.tool.insertPlaceholder1": "Gõ hoặc nhấn", @@ -1062,10 +1071,12 @@ "panel.change": "Thay đổi", "panel.changeBlock": "Thay đổi Node", "panel.checklist": "Danh sách kiểm tra", + "panel.checklistDescription": "Giải quyết các vấn đề sau trước khi xuất bản", "panel.checklistResolved": "Tất cả các vấn đề đã được giải quyết", "panel.checklistTip": "Đảm bảo rằng tất cả các vấn đề đã được giải quyết trước khi xuất bản", "panel.createdBy": "Tạo bởi ", "panel.goTo": "Đi tới", + "panel.goToFix": "Đi đến sửa lỗi", "panel.helpLink": "Trung tâm trợ giúp", "panel.maximize": "Tối đa hóa Canvas", "panel.minimize": "Thoát chế độ toàn màn hình", 
diff --git a/web/i18n/zh-Hant/app-debug.json b/web/i18n/zh-Hant/app-debug.json index ab7286691f..3058cd8458 100644 --- a/web/i18n/zh-Hant/app-debug.json +++ b/web/i18n/zh-Hant/app-debug.json @@ -235,11 +235,16 @@ "inputs.run": "執行", "inputs.title": "除錯與預覽", "inputs.userInputField": "使用者輸入", + "manageModels": "管理模型", "modelConfig.modeType.chat": "對話型", "modelConfig.modeType.completion": "補全型", "modelConfig.model": "語言模型", "modelConfig.setTone": "模型設定", "modelConfig.title": "模型及引數", + "noModelProviderConfigured": "未配置模型供應商", + "noModelProviderConfiguredTip": "請先安裝或配置模型供應商以開始使用。", + "noModelSelected": "未選擇模型", + "noModelSelectedTip": "請先在上方配置模型以繼續。", "noResult": "輸出將顯示在此處。", "notSetAPIKey.description": "在除錯之前需要設定 LLM 提供者的金鑰。", "notSetAPIKey.settingBtn": "去設定", diff --git a/web/i18n/zh-Hant/common.json b/web/i18n/zh-Hant/common.json index 85ff9cc687..9317f68f82 100644 --- a/web/i18n/zh-Hant/common.json +++ b/web/i18n/zh-Hant/common.json @@ -340,11 +340,25 @@ "modelProvider.auth.unAuthorized": "未經授權", "modelProvider.buyQuota": "購買額度", "modelProvider.callTimes": "呼叫次數", + "modelProvider.card.aiCreditsInUse": "AI 額度使用中", + "modelProvider.card.aiCreditsOption": "AI 額度", + "modelProvider.card.apiKeyOption": "API Key", + "modelProvider.card.apiKeyRequired": "需要配置 API Key", + "modelProvider.card.apiKeyUnavailableFallback": "API Key 不可用,正在使用 AI 額度", + "modelProvider.card.apiKeyUnavailableFallbackDescription": "檢查你的 API Key 配置以切換回來", "modelProvider.card.buyQuota": "購買額度", "modelProvider.card.callTimes": "呼叫次數", + "modelProvider.card.creditsExhaustedDescription": "請升級方案或配置 API Key", + "modelProvider.card.creditsExhaustedFallback": "AI 額度已用盡,正在使用 API Key", + "modelProvider.card.creditsExhaustedFallbackDescription": "升級方案以恢復 AI 額度優先使用。", + "modelProvider.card.creditsExhaustedMessage": "AI 額度已用盡", "modelProvider.card.modelAPI": "{{modelName}} 模型正在使用 API Key。", "modelProvider.card.modelNotSupported": "{{modelName}} 模型未安裝。", "modelProvider.card.modelSupported": "{{modelName}} 
模型正在使用此配額。", + "modelProvider.card.noApiKeysDescription": "新增 API Key 以使用自有模型憑證。", + "modelProvider.card.noApiKeysFallback": "未配置 API Key,正在使用 AI 額度", + "modelProvider.card.noApiKeysTitle": "尚未配置 API Key", + "modelProvider.card.noAvailableUsage": "無可用額度", "modelProvider.card.onTrial": "試用中", "modelProvider.card.paid": "已購買", "modelProvider.card.priorityUse": "優先使用", @@ -353,6 +367,11 @@ "modelProvider.card.removeKey": "刪除 API 金鑰", "modelProvider.card.tip": "消息額度支持使用 {{modelNames}} 的模型;免費額度會在付費額度用盡後才會消耗。", "modelProvider.card.tokens": "Tokens", + "modelProvider.card.unavailable": "不可用", + "modelProvider.card.upgradePlan": "升級方案", + "modelProvider.card.usageLabel": "用量", + "modelProvider.card.usagePriority": "使用優先順序", + "modelProvider.card.usagePriorityTip": "設定執行模型時優先使用的資源。", "modelProvider.collapse": "收起", "modelProvider.config": "配置", "modelProvider.configLoadBalancing": "配置負載均衡", @@ -387,9 +406,11 @@ "modelProvider.model": "模型", "modelProvider.modelAndParameters": "模型及引數", "modelProvider.modelHasBeenDeprecated": "此模型已棄用", + "modelProvider.modelSettings": "模型設定", "modelProvider.models": "模型列表", "modelProvider.modelsNum": "{{num}} 個模型", "modelProvider.noModelFound": "找不到模型 {{model}}", + "modelProvider.noneConfigured": "配置預設系統模型以執行應用", "modelProvider.notConfigured": "系統模型尚未完全配置,部分功能可能無法使用。", "modelProvider.parameters": "引數", "modelProvider.parametersInvalidRemoved": "一些參數無效,已被移除", @@ -403,8 +424,25 @@ "modelProvider.resetDate": "於 {{date}} 重置", "modelProvider.searchModel": "搜尋模型", "modelProvider.selectModel": "選擇您的模型", + "modelProvider.selector.aiCredits": "AI 額度", + "modelProvider.selector.apiKeyUnavailable": "API Key 不可用", + "modelProvider.selector.apiKeyUnavailableTip": "API Key 已被移除,請重新配置 API Key。", + "modelProvider.selector.configure": "配置", + "modelProvider.selector.configureRequired": "需要配置", + "modelProvider.selector.creditsExhausted": "額度已用盡", + "modelProvider.selector.creditsExhaustedTip": "AI 額度已用盡,請升級方案或新增 API Key。", + "modelProvider.selector.disabled": 
"已停用", + "modelProvider.selector.discoverMoreInMarketplace": "在插件市場探索更多", "modelProvider.selector.emptySetting": "請前往設定進行配置", "modelProvider.selector.emptyTip": "無可用模型", + "modelProvider.selector.fromMarketplace": "從插件市場安裝", + "modelProvider.selector.incompatible": "不相容", + "modelProvider.selector.incompatibleTip": "此模型在目前版本中不可用,請選擇其他可用模型。", + "modelProvider.selector.install": "安裝", + "modelProvider.selector.modelProviderSettings": "模型供應商設定", + "modelProvider.selector.noProviderConfigured": "未配置模型供應商", + "modelProvider.selector.noProviderConfiguredDesc": "前往插件市場安裝,或在設定中配置供應商。", + "modelProvider.selector.onlyCompatibleModelsShown": "僅顯示相容的模型", "modelProvider.selector.rerankTip": "請設定 Rerank 模型", "modelProvider.selector.tip": "該模型已被刪除。請添模型或選擇其他模型。", "modelProvider.setupModelFirst": "請先設定您的模型", diff --git a/web/i18n/zh-Hant/plugin.json b/web/i18n/zh-Hant/plugin.json index 20630f41ab..4e7d577aaa 100644 --- a/web/i18n/zh-Hant/plugin.json +++ b/web/i18n/zh-Hant/plugin.json @@ -3,6 +3,7 @@ "action.delete": "刪除插件", "action.deleteContentLeft": "是否要刪除", "action.deleteContentRight": "插件?", + "action.deleteSuccess": "插件移除成功", "action.pluginInfo": "插件資訊", "action.usedInApps": "此插件正在 {{num}} 個應用程式中使用。", "allCategories": "全部分類", @@ -114,6 +115,7 @@ "detailPanel.operation.install": "安裝", "detailPanel.operation.remove": "刪除", "detailPanel.operation.update": "更新", + "detailPanel.operation.updateTooltip": "更新以取得最新模型。", "detailPanel.operation.viewDetail": "查看詳情", "detailPanel.serviceOk": "服務正常", "detailPanel.strategyNum": "{{num}} {{strategy}} 包括", @@ -231,12 +233,18 @@ "source.local": "本地包檔", "source.marketplace": "市場", "task.clearAll": "全部清除", + "task.errorMsg.github": "此插件無法自動安裝。\n請從 GitHub 安裝。", + "task.errorMsg.marketplace": "此插件無法自動安裝。\n請從插件市場安裝。", + "task.errorMsg.unknown": "此插件無法安裝。\n無法識別插件來源。", "task.errorPlugins": "Failed to Install Plugins", "task.installError": "{{errorLength}} 個插件安裝失敗,點擊查看", + "task.installFromGithub": "從 GitHub 安裝", + "task.installFromMarketplace": 
"從插件市場安裝", "task.installSuccess": "{{successLength}} plugins installed successfully", "task.installed": "Installed", "task.installedError": "{{errorLength}} 個插件安裝失敗", "task.installing": "正在安裝插件。", + "task.installingHint": "正在安裝……可能需要幾分鐘。", "task.installingWithError": "安裝 {{installingLength}} 個插件,{{successLength}} 成功,{{errorLength}} 失敗", "task.installingWithSuccess": "安裝 {{installingLength}} 個插件,{{successLength}} 成功。", "task.runningPlugins": "Installing Plugins", diff --git a/web/i18n/zh-Hant/workflow.json b/web/i18n/zh-Hant/workflow.json index 8a3c703f55..b739984977 100644 --- a/web/i18n/zh-Hant/workflow.json +++ b/web/i18n/zh-Hant/workflow.json @@ -305,6 +305,7 @@ "error.operations.updatingWorkflow": "更新工作流程", "error.startNodeRequired": "請先新增一個起始節點,再執行 {{operation}}", "errorMsg.authRequired": "請先授權", + "errorMsg.configureModel": "請配置模型", "errorMsg.fieldRequired": "{{field}} 不能為空", "errorMsg.fields.code": "程式碼", "errorMsg.fields.model": "模型", @@ -314,6 +315,7 @@ "errorMsg.fields.visionVariable": "Vision Variable", "errorMsg.invalidJson": "{{field}} 是非法的 JSON", "errorMsg.invalidVariable": "無效的變數", + "errorMsg.modelPluginNotInstalled": "無效的變數。請配置模型以啟用此變數。", "errorMsg.noValidTool": "{{field}} 未選擇有效工具", "errorMsg.rerankModelRequired": "在開啟 Rerank 模型之前,請在設置中確認模型配置成功。", "errorMsg.startNodeRequired": "請先新增一個起始節點,再執行 {{operation}}", @@ -444,6 +446,7 @@ "nodes.common.memory.windowSize": "記憶窗口", "nodes.common.outputVars": "輸出變數", "nodes.common.pluginNotInstalled": "插件未安裝", + "nodes.common.pluginsNotInstalled": "{{count}} 個插件未安裝", "nodes.common.retry.maxRetries": "最大重試次數", "nodes.common.retry.ms": "毫秒", "nodes.common.retry.retries": "{{num}}重試", @@ -686,9 +689,14 @@ "nodes.knowledgeBase.chunksInput": "區塊", "nodes.knowledgeBase.chunksInputTip": "知識庫節點的輸入變數是 Chunks。該變數類型是一個物件,具有特定的 JSON Schema,必須與所選的塊結構一致。", "nodes.knowledgeBase.chunksVariableIsRequired": "Chunks 變數是必需的", + "nodes.knowledgeBase.embeddingModelApiKeyUnavailable": "API Key 不可用", + 
"nodes.knowledgeBase.embeddingModelCreditsExhausted": "額度已用盡", + "nodes.knowledgeBase.embeddingModelIncompatible": "不相容", "nodes.knowledgeBase.embeddingModelIsInvalid": "嵌入模型無效", "nodes.knowledgeBase.embeddingModelIsRequired": "需要嵌入模型", + "nodes.knowledgeBase.embeddingModelNotConfigured": "Embedding 模型未配置", "nodes.knowledgeBase.indexMethodIsRequired": "索引方法是必填的", + "nodes.knowledgeBase.notConfigured": "未配置", "nodes.knowledgeBase.rerankingModelIsInvalid": "重排序模型無效", "nodes.knowledgeBase.rerankingModelIsRequired": "需要重新排序模型", "nodes.knowledgeBase.retrievalSettingIsRequired": "需要檢索設定", @@ -872,6 +880,7 @@ "nodes.templateTransform.codeSupportTip": "只支持 Jinja2", "nodes.templateTransform.inputVars": "輸入變數", "nodes.templateTransform.outputVars.output": "轉換後內容", + "nodes.tool.authorizationRequired": "需要授權", "nodes.tool.authorize": "授權", "nodes.tool.inputVars": "輸入變數", "nodes.tool.insertPlaceholder1": "輸入或按壓", @@ -1062,10 +1071,12 @@ "panel.change": "更改", "panel.changeBlock": "更改節點", "panel.checklist": "檢查清單", + "panel.checklistDescription": "發佈前請解決以下問題", "panel.checklistResolved": "所有問題均已解決", "panel.checklistTip": "發佈前確保所有問題均已解決", "panel.createdBy": "作者", "panel.goTo": "前往", + "panel.goToFix": "前往修復", "panel.helpLink": "查看幫助文件", "panel.maximize": "最大化畫布", "panel.minimize": "退出全螢幕", diff --git a/web/models/common.ts b/web/models/common.ts index 62a543672b..9c8866b25b 100644 --- a/web/models/common.ts +++ b/web/models/common.ts @@ -220,10 +220,6 @@ export type DataSources = { sources: DataSourceItem[] } -export type GithubRepo = { - stargazers_count: number -} - export type PluginProvider = { tool_name: string is_enabled: boolean diff --git a/web/next/server.ts b/web/next/server.ts deleted file mode 100644 index 037538be96..0000000000 --- a/web/next/server.ts +++ /dev/null @@ -1,2 +0,0 @@ -export { NextResponse } from 'next/server' -export type { NextRequest } from 'next/server' diff --git a/web/package.json b/web/package.json index 1b573d84eb..65372ef5f5 100644 --- 
a/web/package.json +++ b/web/package.json @@ -1,7 +1,7 @@ { "name": "dify-web", "type": "module", - "version": "1.13.2", + "version": "1.13.3", "private": true, "packageManager": "pnpm@10.32.1", "imports": { @@ -74,8 +74,6 @@ "@lexical/text": "0.42.0", "@lexical/utils": "0.42.0", "@monaco-editor/react": "4.7.0", - "@octokit/core": "7.0.6", - "@octokit/request-error": "7.1.0", "@orpc/client": "1.13.9", "@orpc/contract": "1.13.9", "@orpc/openapi-client": "1.13.9", @@ -217,18 +215,17 @@ "eslint-plugin-better-tailwindcss": "4.3.2", "eslint-plugin-hyoban": "0.14.1", "eslint-plugin-markdown-preferences": "0.40.3", + "eslint-plugin-no-barrel-files": "1.2.2", "eslint-plugin-react-hooks": "7.0.1", "eslint-plugin-react-refresh": "0.5.2", "eslint-plugin-sonarjs": "4.0.2", "eslint-plugin-storybook": "10.3.1", + "happy-dom": "20.8.9", "hono": "4.12.8", "husky": "9.1.7", "iconify-import-svg": "0.1.2", - "jsdom": "29.0.1", - "jsdom-testing-mocks": "1.16.0", "knip": "6.0.2", "lint-staged": "16.4.0", - "nock": "14.0.11", "postcss": "8.5.8", "postcss-js": "5.1.0", "react-server-dom-webpack": "19.2.4", @@ -278,6 +275,8 @@ "object.values": "npm:@nolyfill/object.values@^1.0.44", "pbkdf2": "~3.1.5", "pbkdf2@<3.1.3": "3.1.3", + "picomatch@<2.3.2": "2.3.2", + "picomatch@>=4.0.0 <4.0.4": "4.0.4", "prismjs": "~1.30", "prismjs@<1.30.0": "1.30.0", "rollup@>=4.0.0 <4.59.0": "4.59.0", @@ -285,6 +284,7 @@ "safe-regex-test": "npm:@nolyfill/safe-regex-test@^1.0.44", "safer-buffer": "npm:@nolyfill/safer-buffer@^1.0.44", "side-channel": "npm:@nolyfill/side-channel@^1.0.44", + "smol-toml@<1.6.1": "1.6.1", "solid-js": "1.9.11", "string-width": "~8.2.0", "string.prototype.includes": "npm:@nolyfill/string.prototype.includes@^1.0.44", @@ -298,6 +298,7 @@ "vite": "npm:@voidzero-dev/vite-plus-core@0.1.13", "vitest": "npm:@voidzero-dev/vite-plus-test@0.1.13", "which-typed-array": "npm:@nolyfill/which-typed-array@^1.0.44", + "yaml@>=2.0.0 <2.8.3": "2.8.3", "yauzl@<3.2.1": "3.2.1" }, 
"ignoredBuiltDependencies": [ diff --git a/web/plugins/dev-proxy/server.spec.ts b/web/plugins/dev-proxy/server.spec.ts index 9c950abae0..c57ec8b4fe 100644 --- a/web/plugins/dev-proxy/server.spec.ts +++ b/web/plugins/dev-proxy/server.spec.ts @@ -1,3 +1,6 @@ +/** + * @vitest-environment node + */ import { beforeEach, describe, expect, it, vi } from 'vitest' import { buildUpstreamUrl, createDevProxyApp, isAllowedDevOrigin, resolveDevProxyTargets } from './server' diff --git a/web/pnpm-lock.yaml b/web/pnpm-lock.yaml index 9945c6f893..cd1a8a8556 100644 --- a/web/pnpm-lock.yaml +++ b/web/pnpm-lock.yaml @@ -35,6 +35,8 @@ overrides: object.values: npm:@nolyfill/object.values@^1.0.44 pbkdf2: ~3.1.5 pbkdf2@<3.1.3: 3.1.3 + picomatch@<2.3.2: 2.3.2 + picomatch@>=4.0.0 <4.0.4: 4.0.4 prismjs: ~1.30 prismjs@<1.30.0: 1.30.0 rollup@>=4.0.0 <4.59.0: 4.59.0 @@ -42,6 +44,7 @@ overrides: safe-regex-test: npm:@nolyfill/safe-regex-test@^1.0.44 safer-buffer: npm:@nolyfill/safer-buffer@^1.0.44 side-channel: npm:@nolyfill/side-channel@^1.0.44 + smol-toml@<1.6.1: 1.6.1 solid-js: 1.9.11 string-width: ~8.2.0 string.prototype.includes: npm:@nolyfill/string.prototype.includes@^1.0.44 @@ -55,6 +58,7 @@ overrides: vite: npm:@voidzero-dev/vite-plus-core@0.1.13 vitest: npm:@voidzero-dev/vite-plus-test@0.1.13 which-typed-array: npm:@nolyfill/which-typed-array@^1.0.44 + yaml@>=2.0.0 <2.8.3: 2.8.3 yauzl@<3.2.1: 3.2.1 importers: @@ -109,12 +113,6 @@ importers: '@monaco-editor/react': specifier: 4.7.0 version: 4.7.0(monaco-editor@0.55.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4) - '@octokit/core': - specifier: 7.0.6 - version: 7.0.6 - '@octokit/request-error': - specifier: 7.1.0 - version: 7.1.0 '@orpc/client': specifier: 1.13.9 version: 1.13.9 @@ -144,7 +142,7 @@ importers: version: 0.13.11(typescript@5.9.3)(valibot@1.3.0(typescript@5.9.3))(zod@4.3.6) '@tailwindcss/typography': specifier: 0.5.19 - version: 0.5.19(tailwindcss@3.4.19(tsx@4.21.0)(yaml@2.8.2)) + version: 
0.5.19(tailwindcss@3.4.19(tsx@4.21.0)(yaml@2.8.3)) '@tanstack/react-form': specifier: 1.28.5 version: 1.28.5(react-dom@19.2.4(react@19.2.4))(react@19.2.4) @@ -373,13 +371,13 @@ importers: devDependencies: '@antfu/eslint-config': specifier: 7.7.3 - version: 7.7.3(@eslint-react/eslint-plugin@3.0.0(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3))(@next/eslint-plugin-next@16.2.1)(@typescript-eslint/rule-tester@8.57.1(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3))(@typescript-eslint/typescript-estree@8.57.1(typescript@5.9.3))(@typescript-eslint/utils@8.57.1(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3))(@voidzero-dev/vite-plus-test@0.1.13(@types/node@25.5.0)(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))(esbuild@0.27.2)(jiti@1.21.7)(jsdom@29.0.1(canvas@3.2.2))(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))(@vue/compiler-sfc@3.5.30)(eslint-plugin-react-hooks@7.0.1(eslint@10.1.0(jiti@1.21.7)))(eslint-plugin-react-refresh@0.5.2(eslint@10.1.0(jiti@1.21.7)))(eslint@10.1.0(jiti@1.21.7))(oxlint@1.56.0(oxlint-tsgolint@0.17.1))(typescript@5.9.3) + version: 
7.7.3(@eslint-react/eslint-plugin@3.0.0(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3))(@next/eslint-plugin-next@16.2.1)(@typescript-eslint/rule-tester@8.57.1(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3))(@typescript-eslint/typescript-estree@8.57.1(typescript@5.9.3))(@typescript-eslint/utils@8.57.1(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3))(@voidzero-dev/vite-plus-test@0.1.13(@types/node@25.5.0)(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(esbuild@0.27.2)(happy-dom@20.8.9)(jiti@1.21.7)(jsdom@29.0.1(canvas@3.2.2))(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(@vue/compiler-sfc@3.5.30)(eslint-plugin-react-hooks@7.0.1(eslint@10.1.0(jiti@1.21.7)))(eslint-plugin-react-refresh@0.5.2(eslint@10.1.0(jiti@1.21.7)))(eslint@10.1.0(jiti@1.21.7))(oxlint@1.56.0(oxlint-tsgolint@0.17.1))(typescript@5.9.3) '@chromatic-com/storybook': specifier: 5.0.2 version: 5.0.2(storybook@10.3.1(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)) '@egoist/tailwindcss-icons': specifier: 1.9.2 - version: 1.9.2(tailwindcss@3.4.19(tsx@4.21.0)(yaml@2.8.2)) + version: 1.9.2(tailwindcss@3.4.19(tsx@4.21.0)(yaml@2.8.3)) '@eslint-react/eslint-plugin': specifier: 3.0.0 version: 3.0.0(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3) @@ -412,7 +410,7 @@ importers: version: 4.2.0 '@storybook/addon-docs': specifier: 10.3.1 - version: 10.3.1(@types/react@19.2.14)(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))(esbuild@0.27.2)(rollup@4.59.0)(storybook@10.3.1(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(webpack@5.105.4(esbuild@0.27.2)(uglify-js@3.19.3)) + version: 
10.3.1(@types/react@19.2.14)(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(esbuild@0.27.2)(rollup@4.59.0)(storybook@10.3.1(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(webpack@5.105.4(esbuild@0.27.2)(uglify-js@3.19.3)) '@storybook/addon-links': specifier: 10.3.1 version: 10.3.1(react@19.2.4)(storybook@10.3.1(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)) @@ -424,7 +422,7 @@ importers: version: 10.3.1(storybook@10.3.1(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)) '@storybook/nextjs-vite': specifier: 10.3.1 - version: 10.3.1(@babel/core@7.29.0)(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))(esbuild@0.27.2)(next@16.2.1(@babel/core@7.29.0)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(sass@1.98.0))(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(rollup@4.59.0)(storybook@10.3.1(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(typescript@5.9.3)(webpack@5.105.4(esbuild@0.27.2)(uglify-js@3.19.3)) + version: 10.3.1(@babel/core@7.29.0)(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(esbuild@0.27.2)(next@16.2.1(@babel/core@7.29.0)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(sass@1.98.0))(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(rollup@4.59.0)(storybook@10.3.1(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(typescript@5.9.3)(webpack@5.105.4(esbuild@0.27.2)(uglify-js@3.19.3)) '@storybook/react': specifier: 10.3.1 version: 10.3.1(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(storybook@10.3.1(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(typescript@5.9.3) @@ -502,13 +500,13 @@ 
importers: version: 7.0.0-dev.20260322.1 '@vitejs/plugin-react': specifier: 6.0.1 - version: 6.0.1(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2)) + version: 6.0.1(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3)) '@vitejs/plugin-rsc': specifier: 0.5.21 - version: 0.5.21(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))(react-dom@19.2.4(react@19.2.4))(react-server-dom-webpack@19.2.4(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(webpack@5.105.4(esbuild@0.27.2)(uglify-js@3.19.3)))(react@19.2.4) + version: 0.5.21(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(react-dom@19.2.4(react@19.2.4))(react-server-dom-webpack@19.2.4(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(webpack@5.105.4(esbuild@0.27.2)(uglify-js@3.19.3)))(react@19.2.4) '@vitest/coverage-v8': specifier: 4.1.0 - version: 4.1.0(@voidzero-dev/vite-plus-test@0.1.13(@types/node@25.5.0)(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))(esbuild@0.27.2)(jiti@1.21.7)(jsdom@29.0.1(canvas@3.2.2))(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2)) + version: 4.1.0(@voidzero-dev/vite-plus-test@0.1.13(@types/node@25.5.0)(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(esbuild@0.27.2)(happy-dom@20.8.9)(jiti@1.21.7)(jsdom@29.0.1(canvas@3.2.2))(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3)) agentation: specifier: 2.3.3 version: 
2.3.3(react-dom@19.2.4(react@19.2.4))(react@19.2.4) @@ -526,13 +524,16 @@ importers: version: 0.6.0(eslint@10.1.0(jiti@1.21.7)) eslint-plugin-better-tailwindcss: specifier: 4.3.2 - version: 4.3.2(eslint@10.1.0(jiti@1.21.7))(oxlint@1.56.0(oxlint-tsgolint@0.17.1))(tailwindcss@3.4.19(tsx@4.21.0)(yaml@2.8.2))(typescript@5.9.3) + version: 4.3.2(eslint@10.1.0(jiti@1.21.7))(oxlint@1.56.0(oxlint-tsgolint@0.17.1))(tailwindcss@3.4.19(tsx@4.21.0)(yaml@2.8.3))(typescript@5.9.3) eslint-plugin-hyoban: specifier: 0.14.1 version: 0.14.1(eslint@10.1.0(jiti@1.21.7)) eslint-plugin-markdown-preferences: specifier: 0.40.3 version: 0.40.3(@eslint/markdown@7.5.1)(eslint@10.1.0(jiti@1.21.7)) + eslint-plugin-no-barrel-files: + specifier: 1.2.2 + version: 1.2.2(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3) eslint-plugin-react-hooks: specifier: 7.0.1 version: 7.0.1(eslint@10.1.0(jiti@1.21.7)) @@ -545,6 +546,9 @@ importers: eslint-plugin-storybook: specifier: 10.3.1 version: 10.3.1(eslint@10.1.0(jiti@1.21.7))(storybook@10.3.1(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(typescript@5.9.3) + happy-dom: + specifier: 20.8.9 + version: 20.8.9 hono: specifier: 4.12.8 version: 4.12.8 @@ -554,21 +558,12 @@ importers: iconify-import-svg: specifier: 0.1.2 version: 0.1.2 - jsdom: - specifier: 29.0.1 - version: 29.0.1(canvas@3.2.2) - jsdom-testing-mocks: - specifier: 1.16.0 - version: 1.16.0 knip: specifier: 6.0.2 version: 6.0.2 lint-staged: specifier: 16.4.0 version: 16.4.0 - nock: - specifier: 14.0.11 - version: 14.0.11 postcss: specifier: 8.5.8 version: 8.5.8 @@ -586,7 +581,7 @@ importers: version: 10.3.1(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4) tailwindcss: specifier: 3.4.19 - version: 3.4.19(tsx@4.21.0)(yaml@2.8.2) + version: 3.4.19(tsx@4.21.0)(yaml@2.8.3) taze: specifier: 19.10.0 version: 19.10.0 @@ -601,22 +596,22 @@ importers: version: 3.19.3 vinext: specifier: https://pkg.pr.new/vinext@b6a2cac - version: 
https://pkg.pr.new/vinext@b6a2cac(1a91bf00ec5f7fb5f0ffb625316f9d01) + version: https://pkg.pr.new/vinext@b6a2cac(33c71b051bfc49f90bf5d8b6a8976975) vite: specifier: npm:@voidzero-dev/vite-plus-core@0.1.13 - version: '@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2)' + version: '@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3)' vite-plugin-inspect: specifier: 11.3.3 - version: 11.3.3(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2)) + version: 11.3.3(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3)) vite-plus: specifier: 0.1.13 - version: 0.1.13(@types/node@25.5.0)(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))(esbuild@0.27.2)(jiti@1.21.7)(jsdom@29.0.1(canvas@3.2.2))(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2) + version: 0.1.13(@types/node@25.5.0)(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(esbuild@0.27.2)(happy-dom@20.8.9)(jiti@1.21.7)(jsdom@29.0.1(canvas@3.2.2))(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3) vitest: specifier: npm:@voidzero-dev/vite-plus-test@0.1.13 - version: 
'@voidzero-dev/vite-plus-test@0.1.13(@types/node@25.5.0)(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))(esbuild@0.27.2)(jiti@1.21.7)(jsdom@29.0.1(canvas@3.2.2))(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2)' + version: '@voidzero-dev/vite-plus-test@0.1.13(@types/node@25.5.0)(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(esbuild@0.27.2)(happy-dom@20.8.9)(jiti@1.21.7)(jsdom@29.0.1(canvas@3.2.2))(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3)' vitest-canvas-mock: specifier: 1.1.3 - version: 1.1.3(@voidzero-dev/vite-plus-test@0.1.13(@types/node@25.5.0)(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))(esbuild@0.27.2)(jiti@1.21.7)(jsdom@29.0.1(canvas@3.2.2))(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2)) + version: 1.1.3(@voidzero-dev/vite-plus-test@0.1.13(@types/node@25.5.0)(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(esbuild@0.27.2)(happy-dom@20.8.9)(jiti@1.21.7)(jsdom@29.0.1(canvas@3.2.2))(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3)) packages: @@ -773,12 +768,12 @@ packages: '@antfu/utils@8.1.1': resolution: {integrity: sha512-Mex9nXf9vR6AhcXmMrlz/HVgYYZpVGJ6YlPgwl7UnaFpnshXs6EK/oa5Gpf3CzENMjkvEx2tQtntGnb7UtSTOQ==} - '@asamuzakjp/css-color@5.0.1': - resolution: {integrity: sha512-2SZFvqMyvboVV1d15lMf7XiI3m7SDqXUuKaTymJYLN6dSGadqp+fVojqJlVoMlbZnlTmu3S0TLwLTJpvBMO1Aw==} + '@asamuzakjp/css-color@5.1.1': + resolution: {integrity: sha512-iGWN8E45Ws0XWx3D44Q1t6vX2LqhCKcwfmwBYCDsFrYFS6m4q/Ks61L2veETaLv+ckDC6+dTETJoaAAb7VjLiw==} engines: {node: 
^20.19.0 || ^22.12.0 || >=24.0.0} - '@asamuzakjp/dom-selector@7.0.3': - resolution: {integrity: sha512-Q6mU0Z6bfj6YvnX2k9n0JxiIwrCFN59x/nWmYQnAqP000ruX/yV+5bp/GRcF5T8ncvfwJQ7fgfP74DlpKExILA==} + '@asamuzakjp/dom-selector@7.0.4': + resolution: {integrity: sha512-jXR6x4AcT3eIrS2fSNAwJpwirOkGcd+E7F7CP3zjdTqz9B/2huHOL8YJZBgekKwLML+u7qB/6P1LXQuMScsx0w==} engines: {node: ^20.19.0 || ^22.12.0 || >=24.0.0} '@asamuzakjp/nwsapi@2.3.9': @@ -962,8 +957,8 @@ packages: peerDependencies: '@csstools/css-tokenizer': ^4.0.0 - '@csstools/css-syntax-patches-for-csstree@1.1.1': - resolution: {integrity: sha512-BvqN0AMWNAnLk9G8jnUT77D+mUbY/H2b3uDTvg2isJkHaOufUE2R3AOwxWo7VBQKT1lOdwdvorddo2B/lk64+w==} + '@csstools/css-syntax-patches-for-csstree@1.1.2': + resolution: {integrity: sha512-5GkLzz4prTIpoyeUiIu3iV6CSG3Plo7xRVOFPKI7FVEJ3mZ0A8SwK0XU3Gl7xAkiQ+mDyam+NNp875/C5y+jSA==} peerDependencies: css-tree: ^3.2.1 peerDependenciesMeta: @@ -1685,10 +1680,6 @@ packages: react: ^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0 react-dom: ^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0 - '@mswjs/interceptors@0.41.3': - resolution: {integrity: sha512-cXu86tF4VQVfwz8W1SPbhoRyHJkti6mjH/XJIxp40jhO4j2k1m4KYrEykxqWPkFF3vrK4rgQppBh//AwyGSXPA==} - engines: {node: '>=18'} - '@napi-rs/wasm-runtime@1.1.1': resolution: {integrity: sha512-p64ah1M1ld8xjWv3qbvFwHiFVWrq1yFvV4f7w+mzaqiR4IlSgkqhcRdHwsGgomwzBH51sRY4NEowLxnaBjcW/A==} @@ -1791,45 +1782,6 @@ packages: resolution: {integrity: sha512-y3SvzjuY1ygnzWA4Krwx/WaJAsTMP11DN+e21A8Fa8PW1oDtVB5NSRW7LWurAiS2oKRkuCgcjTYMkBuBkcPCRg==} engines: {node: '>=12.4.0'} - '@octokit/auth-token@6.0.0': - resolution: {integrity: sha512-P4YJBPdPSpWTQ1NU4XYdvHvXJJDxM6YwpS0FZHRgP7YFkdVxsWcpWGy/NVqlAA7PcPCnMacXlRm1y2PFZRWL/w==} - engines: {node: '>= 20'} - - '@octokit/core@7.0.6': - resolution: {integrity: sha512-DhGl4xMVFGVIyMwswXeyzdL4uXD5OGILGX5N8Y+f6W7LhC1Ze2poSNrkF/fedpVDHEEZ+PHFW0vL14I+mm8K3Q==} - engines: {node: '>= 20'} - - '@octokit/endpoint@11.0.3': - resolution: {integrity: 
sha512-FWFlNxghg4HrXkD3ifYbS/IdL/mDHjh9QcsNyhQjN8dplUoZbejsdpmuqdA76nxj2xoWPs7p8uX2SNr9rYu0Ag==} - engines: {node: '>= 20'} - - '@octokit/graphql@9.0.3': - resolution: {integrity: sha512-grAEuupr/C1rALFnXTv6ZQhFuL1D8G5y8CN04RgrO4FIPMrtm+mcZzFG7dcBm+nq+1ppNixu+Jd78aeJOYxlGA==} - engines: {node: '>= 20'} - - '@octokit/openapi-types@27.0.0': - resolution: {integrity: sha512-whrdktVs1h6gtR+09+QsNk2+FO+49j6ga1c55YZudfEG+oKJVvJLQi3zkOm5JjiUXAagWK2tI2kTGKJ2Ys7MGA==} - - '@octokit/request-error@7.1.0': - resolution: {integrity: sha512-KMQIfq5sOPpkQYajXHwnhjCC0slzCNScLHs9JafXc4RAJI+9f+jNDlBNaIMTvazOPLgb4BnlhGJOTbnN0wIjPw==} - engines: {node: '>= 20'} - - '@octokit/request@10.0.8': - resolution: {integrity: sha512-SJZNwY9pur9Agf7l87ywFi14W+Hd9Jg6Ifivsd33+/bGUQIjNujdFiXII2/qSlN2ybqUHfp5xpekMEjIBTjlSw==} - engines: {node: '>= 20'} - - '@octokit/types@16.0.0': - resolution: {integrity: sha512-sKq+9r1Mm4efXW1FCk7hFSeJo4QKreL/tTbR0rz/qx/r1Oa2VV83LTA/H/MuCOX7uCIJmQVRKBcbmWoySjAnSg==} - - '@open-draft/deferred-promise@2.2.0': - resolution: {integrity: sha512-CecwLWx3rhxVQF6V4bAgPS5t+So2sTbPgAzafKkVizyi7tlwpcFpdFqq+wqF2OwNBmqFuu6tOyouTuxgpMfzmA==} - - '@open-draft/logger@0.3.0': - resolution: {integrity: sha512-X2g45fzhxH238HKO4xbSr7+wBS8Fvw6ixhTDuvLd5mqh6bJJCFAPwU9mPDxbcrRtfxv4u5IHCEH77BmxvXmmxQ==} - - '@open-draft/until@2.1.0': - resolution: {integrity: sha512-U69T3ItWHvLwGg5eJ0n3I62nWuE6ilHlmz7zM0npLBRvPRd7e6NYmg54vvRtP5mZG7kZqZCFVdsTWo7BPtBujg==} - '@orpc/client@1.13.9': resolution: {integrity: sha512-RmD2HDgmGgF6zgHHdybE4zH6QJoHjC+/C3n56yLf+fmWbiZtwnOUETgGCroY6S8aK2fpy6hJ3wZaJUjfWVuGHg==} @@ -3555,6 +3507,12 @@ packages: '@types/unist@3.0.3': resolution: {integrity: sha512-ko/gIFJRv177XgZsZcBwnqJN5x/Gien8qNOn0D5bQU/zAzVf9Zt3BlcUiLqhV9y4ARk0GbT3tnUiPNgnTXzc/Q==} + '@types/whatwg-mimetype@3.0.2': + resolution: {integrity: sha512-c2AKvDT8ToxLIOUlN51gTiHXflsfIFisS4pO7pDPoKouJCESkhZnEy623gwP9laCy5lnLDAw1vAzu2vM2YLOrA==} + + '@types/ws@8.18.1': + resolution: {integrity: 
sha512-ThVF6DCVhA8kUGy+aazFQ4kXQ7E1Ty7A3ypFOe0IcJV8O/M511G99AW24irKrW56Wt44yG9+ij8FaqoBGkuBXg==} + '@types/yauzl@2.10.3': resolution: {integrity: sha512-oJoftv0LSuaDZE3Le4DbKX+KS9G36NzOeSap90UIK0yMA/NhKJhqlSGtNDORNRaIbQfzjXDrQa0ytJ6mNRGz/Q==} @@ -3778,7 +3736,7 @@ packages: tsx: ^4.8.1 typescript: ^5.0.0 unplugin-unused: ^0.5.0 - yaml: ^2.4.2 + yaml: 2.8.3 peerDependenciesMeta: '@arethetypeswrong/core': optional: true @@ -4127,12 +4085,6 @@ packages: engines: {node: '>=6.0.0'} hasBin: true - before-after-hook@4.0.0: - resolution: {integrity: sha512-q6tR3RPqIB1pMiTRMFcZwuG5T8vwp+vUvEG0vuI6B+Rikh5BfPp2fQ82c925FOs+b0lcFQ8CFrL+KbilfZFhOQ==} - - bezier-easing@2.1.0: - resolution: {integrity: sha512-gbIqZ/eslnUFC1tjEvtz0sgx+xTK20wDnYMIA27VA04R7w6xxXQPZDbibjA9DTWZRA2CXtwHykkVzlCaAJAZig==} - bidi-js@1.0.3: resolution: {integrity: sha512-RKshQI1R3YQ+n9YJz2QQ147P66ELpa1FQEg20Dk8oW9t2KgLbpDLLp9aGZ7y8WHSshDknG0bknqGw5/tyCs5tw==} @@ -4433,9 +4385,6 @@ packages: resolution: {integrity: sha512-3O5QdqgFRUbXvK1x5INf1YkBz1UKSWqrd63vWsum8MNHDBYD5urm3QtxZbKU259OrEXNM26lP/MPY3d1IGkBgA==} engines: {node: '>=16'} - css-mediaquery@0.1.2: - resolution: {integrity: sha512-COtn4EROW5dBGlE/4PiKnh6rZpAPxDeFLaEEwt4i10jpDMFt2EhQGS79QmmrO+iKCHv0PU/HrOWEhijFd1x99Q==} - css-select@5.2.2: resolution: {integrity: sha512-TizTzUddG/xYLA3NXodFM0fSbNizXjOKhqiQQwvhlspadZokn1KDy0NZFS0wuEubIYAV5/c1/lAr0TaaFXEXzw==} @@ -4972,6 +4921,9 @@ packages: peerDependencies: eslint: '>=8.23.0' + eslint-plugin-no-barrel-files@1.2.2: + resolution: {integrity: sha512-DF2bnHuEHClmL1+maBO5TD2HnnRsLj8J69FFtVkjObkELyjCXaWBsk+URJkqBpdOWURlL+raGX9AEpWCAiOV0g==} + eslint-plugin-no-only-tests@3.3.0: resolution: {integrity: sha512-brcKcxGnISN2CcVhXJ/kEQlNa0MEfGRtwKtWA16SkqXHKitaKIMrfemJKLKX1YqDU5C/5JY3PvZXd5jEW04e0Q==} engines: {node: '>=5.0.0'} @@ -5223,9 +5175,6 @@ packages: engines: {node: '>= 10.17.0'} hasBin: true - fast-content-type-parse@3.0.0: - resolution: {integrity: 
sha512-ZvLdcY8P+N8mGQJahJV5G4U88CSvT1rP8ApL6uETe88MBXrBHAkZlSEySdUlyztF7ccb+Znos3TFqaepHxdhBg==} - fast-deep-equal@3.1.3: resolution: {integrity: sha512-f3qQ9oQy9j2AhBe/H9VC91wLmKBCCU/gDOnKNAYG5hswO7BLKj09Hc5HYNz9cGI++xlpDCIgDaitVs03ATR84Q==} @@ -5262,7 +5211,7 @@ packages: resolution: {integrity: sha512-tIbYtZbucOs0BRGqPJkshJUYdL+SDH7dVM8gjy+ERp3WAUjLEFJE+02kanyHtwjWOnwrKYBiwAmM0p4kLJAnXg==} engines: {node: '>=12.0.0'} peerDependencies: - picomatch: ^3 || ^4 + picomatch: 4.0.4 peerDependenciesMeta: picomatch: optional: true @@ -5407,6 +5356,10 @@ packages: hachure-fill@0.5.2: resolution: {integrity: sha512-3GKBOn+m2LX9iq+JC1064cSFprJY4jL1jCXTcpnfER5HYE2l/4EfWSGzkPa/ZDBmYI0ZOEj5VHV/eKnPGkHuOg==} + happy-dom@20.8.9: + resolution: {integrity: sha512-Tz23LR9T9jOGVZm2x1EPdXqwA37G/owYMxRwU0E4miurAtFsPMQ1d2Jc2okUaSjZqAFz2oEn3FLXC5a0a+siyA==} + engines: {node: '>=20.0.0'} + has-flag@4.0.0: resolution: {integrity: sha512-EykJT/Q1KjTWctppgIAgfSO0tKVuZUjhgMr17kqTumMl6Afv3EISleU7qZUzoXDFTAHTDC4NOoG/ZxU3EvlMPQ==} engines: {node: '>=8'} @@ -5645,9 +5598,6 @@ packages: engines: {node: '>=14.16'} hasBin: true - is-node-process@1.2.0: - resolution: {integrity: sha512-Vg4o6/fqPxIjtxgUH5QLJhwZ7gW5diGCVlXpuUfELC62CuxM1iHcRe51f2W1FDy04Ai4KJkagKjx3XaqyfRKXw==} - is-number@7.0.0: resolution: {integrity: sha512-41Cifkg6e8TylSpdtTpeLVMqvSBEVzTttHvERD741+pnZ8ANv0004MRL43QKPDlK9cGvNp6NZWZUBlbGXYxxng==} engines: {node: '>=0.12.0'} @@ -5741,10 +5691,6 @@ packages: resolution: {integrity: sha512-/2uqY7x6bsrpi3i9LVU6J89352C0rpMk0as8trXxCtvd4kPk1ke/Eyif6wqfSLvoNJqcDG9Vk4UsXgygzCt2xA==} engines: {node: '>=20.0.0'} - jsdom-testing-mocks@1.16.0: - resolution: {integrity: sha512-wLrulXiLpjmcUYOYGEvz4XARkrmdVpyxzdBl9IAMbQ+ib2/UhUTRCn49McdNfXLff2ysGBUms49ZKX0LR1Q0gg==} - engines: {node: '>=14'} - jsdom@29.0.1: resolution: {integrity: sha512-z6JOK5gRO7aMybVq/y/MlIpKh8JIi68FBKMUtKkK2KH/wMSRlCxQ682d08LB9fYXplyY/UXG8P4XXTScmdjApg==} engines: {node: ^20.19.0 || ^22.13.0 || >=24.0.0} @@ -5774,12 +5720,6 @@ 
packages: json-stable-stringify-without-jsonify@1.0.1: resolution: {integrity: sha512-Bdboy+l7tA3OGW6FjyFHWkP5LuByj1Tk33Ljyq0axyzdk9//JSi2u3fP1QSmd1KNwq6VOKYGlAu87CisVir6Pw==} - json-stringify-safe@5.0.1: - resolution: {integrity: sha512-ZClg6AaYvamvYEE82d3Iyd3vSSIjQ+odgjaTzRuO3s7toCdFKczob2i0zCh7JE8kWn17yvAWhUVxvqGwUalsRA==} - - json-with-bigint@3.5.7: - resolution: {integrity: sha512-7ei3MdAI5+fJPVnKlW77TKNKwQ5ppSzWvhPuSuINT/GYW9ZOC1eRKOuhV9yHG5aEsUPj9BBx5JIekkmoLHxZOw==} - json5@2.2.3: resolution: {integrity: sha512-XmOWe7eyHYH14cLdVPoyg+GOH3rYX++KpzrylJwSW98t3Nk+U8XOl8FWKOgwtzdb8lXGf6zYwDUzeHMWfxasyg==} engines: {node: '>=6'} @@ -6359,10 +6299,6 @@ packages: sass: optional: true - nock@14.0.11: - resolution: {integrity: sha512-u5xUnYE+UOOBA6SpELJheMCtj2Laqx15Vl70QxKo43Wz/6nMHXS7PrEioXLjXAwhmawdEMNImwKCcPhBJWbKVw==} - engines: {node: '>=18.20.0 <20 || >=20.12.1'} - node-abi@3.89.0: resolution: {integrity: sha512-6u9UwL0HlAl21+agMN3YAMXcKByMqwGx+pq+P76vii5f7hTPtKDp08/H9py6DY+cfDw7kQNTGEj/rly3IgbNQA==} engines: {node: '>=10'} @@ -6445,9 +6381,6 @@ packages: resolution: {integrity: sha512-6IpQ7mKUxRcZNLIObR0hz7lxsapSSIYNZJwXPGeF0mTVqGKFIXj1DQcMoT22S3ROcLyY/rz0PWaWZ9ayWmad9g==} engines: {node: '>= 0.8.0'} - outvariant@1.4.3: - resolution: {integrity: sha512-+Sl2UErvtsoajRDKCE5/dBz4DIvHXQQnAxtQTF04OJxY0+DyZXSo5P5Bb7XYWOh81syohlYL24hbDwxedPUJCA==} - oxc-parser@0.120.0: resolution: {integrity: sha512-WyPWZlcIm+Fkte63FGfgFB8mAAk33aH9h5N9lphXVOHSXEBFFsmYdOBedVKly363aWABjZdaj/m9lBfEY4wt+w==} engines: {node: ^20.19.0 || >=22.12.0} @@ -6574,12 +6507,12 @@ packages: picocolors@1.1.1: resolution: {integrity: sha512-xceH2snhtb5M9liqDsmEw56le376mTZkEX/jEb/RxNFyegNul7eNslCXP9FDj/Lcu0X8KEyMceP2ntpaHrDEVA==} - picomatch@2.3.1: - resolution: {integrity: sha512-JU3teHTNjmE2VCGFzuY8EXzCDVwEqB2a8fsIvwaStHhAWJEeVd1o1QD80CU6+ZdEXXSLbSsuLwJjkCBWqRQUVA==} + picomatch@2.3.2: + resolution: {integrity: 
sha512-V7+vQEJ06Z+c5tSye8S+nHUfI51xoXIXjHQ99cQtKUkQqqO1kO/KCJUfZXuB47h/YBlDhah2H3hdUGXn8ie0oA==} engines: {node: '>=8.6'} - picomatch@4.0.3: - resolution: {integrity: sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==} + picomatch@4.0.4: + resolution: {integrity: sha512-QP88BAKvMam/3NxH6vj2o21R6MjxZUAd6nlwAS/pnGvN9IVLocLHxGYIzFhg6fUQ+5th6P4dv4eW9jX3DSIj7A==} engines: {node: '>=12'} pify@2.3.0: @@ -6649,7 +6582,7 @@ packages: jiti: '>=1.21.0' postcss: '>=8.0.9' tsx: ^4.8.1 - yaml: ^2.4.2 + yaml: 2.8.3 peerDependenciesMeta: jiti: optional: true @@ -6710,10 +6643,6 @@ packages: prop-types@15.8.1: resolution: {integrity: sha512-oj87CgZICdulUohogVAR7AjlC0327U4el4L6eAvOqCeudMDVU0NThNaV+b9Df4dXgSP1gXMTnPdhfe/2qDH5cg==} - propagate@2.0.1: - resolution: {integrity: sha512-vGrhOavPSTz4QVNuBNdcNXePNdNMaO1xj9yBeH1ScQPjk/rhg9sSlCXPhMkFuaNNW/syTvYqsnbIJxMBfRbbag==} - engines: {node: '>= 8'} - property-information@5.6.0: resolution: {integrity: sha512-YUHSPk+A30YPv+0Qf8i9Mbfe/C0hdPXk1s1jPVToV8pk8BQtpw10ct89Eo7OWkutrwqvT0eicAxlOg3dOAu8JA==} @@ -7181,8 +7110,8 @@ packages: resolution: {integrity: sha512-stxByr12oeeOyY2BlviTNQlYV5xOj47GirPr4yA1hE9JCtxfQN0+tVbkxwCtYDQWhEKWFHsEK48ORg5jrouCAg==} engines: {node: '>=20'} - smol-toml@1.6.0: - resolution: {integrity: sha512-4zemZi0HvTnYwLfrpk/CF9LOd9Lt87kAt50GnqhMpyF9U3poDAP2+iukq2bZsO/ufegbYehBkqINbsWxj4l4cw==} + smol-toml@1.6.1: + resolution: {integrity: sha512-dWUG8F5sIIARXih1DTaQAX4SsiTXhInKf1buxdY9DIg4ZYPZK5nGM1VRIYmEbDbsHt7USo99xSLFu5Q1IqTmsg==} engines: {node: '>= 18'} solid-js@1.9.11: @@ -7251,9 +7180,6 @@ packages: react: ^18.0.0 || ^19.0.0 react-dom: ^18.0.0 || ^19.0.0 - strict-event-emitter@0.5.1: - resolution: {integrity: sha512-vMgjE/GGEPEFnhFub6pa4FmJBRBVOLpIII2hvCZ8Kzb7K0hlHo7mQv6xYrBvCL2LtAIBwFUK8wvuJgTVSQ5MFQ==} - string-argv@0.3.2: resolution: {integrity: sha512-aqD2Q0144Z+/RqG52NeHEkZauTAUWJO8c6yTftGJKO3Tja5tUgIfmIl6kExvhtxSDP7fXB6DvzkfMpCd/F3G+Q==} engines: {node: '>=0.6.19'} 
@@ -7387,6 +7313,10 @@ packages: resolution: {integrity: sha512-g9ljZiwki/LfxmQADO3dEY1CbpmXT5Hm2fJ+QaGKwSXUylMybePR7/67YW7jOrrvjEgL1Fmz5kzyAjWVWLlucg==} engines: {node: '>=6'} + tapable@2.3.2: + resolution: {integrity: sha512-1MOpMXuhGzGL5TTCZFItxCc0AARf1EZFQkGqMm7ERKj8+Hgr5oLvJOVFcC+lRmR8hCe2S3jC4T5D7Vg/d7/fhA==} + engines: {node: '>=6'} + tar-fs@2.1.4: resolution: {integrity: sha512-mDAjwmZdh7LTT6pNleZ05Yt65HC3E+NiQzl672vQG38jIrehtJk/J3mNwIg+vShQPcLF/LV7CMnDW6vjj6sfYQ==} @@ -7606,8 +7536,8 @@ packages: resolution: {integrity: sha512-jxytwMHhsbdpBXxLAcuu0fzlQeXCNnWdDyRHpvWsUl8vd98UwYdl9YTyn8/HcpcJPC3pwUveefsa3zTxyD/ERg==} engines: {node: '>=20.18.1'} - undici@7.24.5: - resolution: {integrity: sha512-3IWdCpjgxp15CbJnsi/Y9TCDE7HWVN19j1hmzVhoAkY/+CJx449tVxT5wZc1Gwg8J+P0LWvzlBzxYRnHJ+1i7Q==} + undici@7.24.6: + resolution: {integrity: sha512-Xi4agocCbRzt0yYMZGMA6ApD7gvtUFaxm4ZmeacWI4cZxaF6C+8I8QfofC20NAePiB/IcvZmzkJ7XPa471AEtA==} engines: {node: '>=20.18.1'} unicode-trie@2.0.0: @@ -7640,9 +7570,6 @@ packages: unist-util-visit@5.1.0: resolution: {integrity: sha512-m+vIdyeCOpdr/QeQCu2EzxX/ohgS8KbnPDgFni4dQsfSCtpz8UqDyY5GjRru8PDKuYn7Fq19j1CQ+nJSsGKOzg==} - universal-user-agent@7.0.3: - resolution: {integrity: sha512-TmnEAEAsBJVZM/AADELsK76llnwcf9vMKuPz8JflO1frO8Lchitr0fNaN9d+Ap0BjKtqWqd/J17qeDnXh8CL2A==} - universalify@2.0.1: resolution: {integrity: sha512-gptHNQghINnc/vTGIk0SOFGFNXw7JVrlRUtConJRlvaw6DuX0wO5Jeko9sWrMBhh+PsYAZ7oXAiOnf/UKogyiw==} engines: {node: '>= 10.0.0'} @@ -7757,7 +7684,7 @@ packages: resolution: {integrity: sha512-KzIbH/9tXat2u30jf+smMwFCsno4wHVdNmzFyL+T/L3UGqqk6JKfVqOFOZEpZSHADH1k40ab6NUIXZq422ov3Q==} vinext@https://pkg.pr.new/vinext@b6a2cac: - resolution: {integrity: sha512-/Jm507qqC1dCOhCaorb9H8/I5JEqkcsiUJw0Wgprg7Znym4eyLUvcWcRLVyM9z22Tm0+O1PugcSDA8oNvbqPuQ==, tarball: https://pkg.pr.new/vinext@b6a2cac} + resolution: {tarball: https://pkg.pr.new/vinext@b6a2cac} version: 0.0.5 engines: {node: '>=22'} hasBin: true @@ -7915,6 +7842,10 @@ packages: 
engines: {node: '>=18'} deprecated: Use @exodus/bytes instead for a more spec-conformant and faster implementation + whatwg-mimetype@3.0.0: + resolution: {integrity: sha512-nt+N2dzIutVRxARx1nghPKGv1xHikU7HKdfafKkLNLindmPU/ch3U31NOCGGA/dmPcmb1VlofO0vnKAcsm0o/Q==} + engines: {node: '>=12'} + whatwg-mimetype@4.0.0: resolution: {integrity: sha512-QaKxh0eNIi2mE9p2vEdzfagOKHCcj1pJ56EEHGQOVxp8r9/iszLUUV7v89x9O1p/T+NlTM5W7jW6+cz4Fq1YVg==} engines: {node: '>=18'} @@ -7955,6 +7886,18 @@ packages: utf-8-validate: optional: true + ws@8.20.0: + resolution: {integrity: sha512-sAt8BhgNbzCtgGbt2OxmpuryO63ZoDk/sqaB/znQm94T4fCEsy/yV+7CdC1kJhOU9lboAEU7R3kquuycDoibVA==} + engines: {node: '>=10.0.0'} + peerDependencies: + bufferutil: ^4.0.1 + utf-8-validate: '>=5.0.2' + peerDependenciesMeta: + bufferutil: + optional: true + utf-8-validate: + optional: true + wsl-utils@0.1.0: resolution: {integrity: sha512-h3Fbisa2nKGPxCpm89Hk33lBLsnaGBvctQopaBSOW/uIs6FTe1ATyAnKFJrzVs9vpGdsTe73WF3V4lIsk4Gacw==} engines: {node: '>=18'} @@ -7985,8 +7928,8 @@ packages: resolution: {integrity: sha512-h0uDm97wvT2bokfwwTmY6kJ1hp6YDFL0nRHwNKz8s/VD1FH/vvZjAKoMUE+un0eaYBSG7/c6h+lJTP+31tjgTw==} engines: {node: ^20.19.0 || ^22.13.0 || >=24} - yaml@2.8.2: - resolution: {integrity: sha512-mplynKqc1C2hTVYxd0PU2xQAc22TI1vShAYGksCCfxbn/dFwnHTNi1bvYsBTkhdUNtGIf5xNOg938rrSSYvS9A==} + yaml@2.8.3: + resolution: {integrity: sha512-AvbaCLOO2Otw/lW5bmh9d/WEdcDFdQp2Z2ZUH3pX9U2ihyUY0nvLv7J6TrWowklRGPYbB/IuIMfYgxaCPg5Bpg==} engines: {node: '>= 14.6'} hasBin: true @@ -8214,7 +8157,7 @@ snapshots: idb: 8.0.0 tslib: 2.8.1 - 
'@antfu/eslint-config@7.7.3(@eslint-react/eslint-plugin@3.0.0(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3))(@next/eslint-plugin-next@16.2.1)(@typescript-eslint/rule-tester@8.57.1(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3))(@typescript-eslint/typescript-estree@8.57.1(typescript@5.9.3))(@typescript-eslint/utils@8.57.1(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3))(@voidzero-dev/vite-plus-test@0.1.13(@types/node@25.5.0)(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))(esbuild@0.27.2)(jiti@1.21.7)(jsdom@29.0.1(canvas@3.2.2))(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))(@vue/compiler-sfc@3.5.30)(eslint-plugin-react-hooks@7.0.1(eslint@10.1.0(jiti@1.21.7)))(eslint-plugin-react-refresh@0.5.2(eslint@10.1.0(jiti@1.21.7)))(eslint@10.1.0(jiti@1.21.7))(oxlint@1.56.0(oxlint-tsgolint@0.17.1))(typescript@5.9.3)': + '@antfu/eslint-config@7.7.3(@eslint-react/eslint-plugin@3.0.0(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3))(@next/eslint-plugin-next@16.2.1)(@typescript-eslint/rule-tester@8.57.1(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3))(@typescript-eslint/typescript-estree@8.57.1(typescript@5.9.3))(@typescript-eslint/utils@8.57.1(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3))(@voidzero-dev/vite-plus-test@0.1.13(@types/node@25.5.0)(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(esbuild@0.27.2)(happy-dom@20.8.9)(jiti@1.21.7)(jsdom@29.0.1(canvas@3.2.2))(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(@vue/compiler-sfc@3.5.30)(eslint-plugin-react-hooks@7.0.1(eslint@10.1.0(jiti@1.21.7)))(eslint-plugin-react-refresh@0.5.2(eslint@10.1.0(jiti@1.21.7)))(eslint@10.1.0(jiti@1.21.7))(oxlint@1.56.0(oxlint-tsgolint@0.17.1))(typescript@5.9.3)': dependencies: '@antfu/install-pkg': 1.1.0 '@clack/prompts': 1.1.0 @@ -8224,7 +8167,7 @@ 
snapshots: '@stylistic/eslint-plugin': 5.10.0(eslint@10.1.0(jiti@1.21.7)) '@typescript-eslint/eslint-plugin': 8.57.1(@typescript-eslint/parser@8.57.1(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3))(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3) '@typescript-eslint/parser': 8.57.1(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3) - '@vitest/eslint-plugin': 1.6.12(@voidzero-dev/vite-plus-test@0.1.13(@types/node@25.5.0)(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))(esbuild@0.27.2)(jiti@1.21.7)(jsdom@29.0.1(canvas@3.2.2))(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3) + '@vitest/eslint-plugin': 1.6.12(@voidzero-dev/vite-plus-test@0.1.13(@types/node@25.5.0)(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(esbuild@0.27.2)(happy-dom@20.8.9)(jiti@1.21.7)(jsdom@29.0.1(canvas@3.2.2))(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3) ansis: 4.2.0 cac: 7.0.0 eslint: 10.1.0(jiti@1.21.7) @@ -8284,23 +8227,26 @@ snapshots: '@antfu/utils@8.1.1': {} - '@asamuzakjp/css-color@5.0.1': + '@asamuzakjp/css-color@5.1.1': dependencies: '@csstools/css-calc': 3.1.1(@csstools/css-parser-algorithms@4.0.0(@csstools/css-tokenizer@4.0.0))(@csstools/css-tokenizer@4.0.0) '@csstools/css-color-parser': 4.0.2(@csstools/css-parser-algorithms@4.0.0(@csstools/css-tokenizer@4.0.0))(@csstools/css-tokenizer@4.0.0) '@csstools/css-parser-algorithms': 4.0.0(@csstools/css-tokenizer@4.0.0) '@csstools/css-tokenizer': 4.0.0 lru-cache: 11.2.7 + optional: true - '@asamuzakjp/dom-selector@7.0.3': + '@asamuzakjp/dom-selector@7.0.4': dependencies: '@asamuzakjp/nwsapi': 2.3.9 bidi-js: 1.0.3 css-tree: 3.2.1 is-potential-custom-element-name: 1.0.1 lru-cache: 11.2.7 + optional: 
true - '@asamuzakjp/nwsapi@2.3.9': {} + '@asamuzakjp/nwsapi@2.3.9': + optional: true '@babel/code-frame@7.29.0': dependencies: @@ -8435,6 +8381,7 @@ snapshots: '@bramus/specificity@2.4.2': dependencies: css-tree: 3.2.1 + optional: true '@chevrotain/cst-dts-gen@11.1.2': dependencies: @@ -8527,12 +8474,14 @@ snapshots: transitivePeerDependencies: - supports-color - '@csstools/color-helpers@6.0.2': {} + '@csstools/color-helpers@6.0.2': + optional: true '@csstools/css-calc@3.1.1(@csstools/css-parser-algorithms@4.0.0(@csstools/css-tokenizer@4.0.0))(@csstools/css-tokenizer@4.0.0)': dependencies: '@csstools/css-parser-algorithms': 4.0.0(@csstools/css-tokenizer@4.0.0) '@csstools/css-tokenizer': 4.0.0 + optional: true '@csstools/css-color-parser@4.0.2(@csstools/css-parser-algorithms@4.0.0(@csstools/css-tokenizer@4.0.0))(@csstools/css-tokenizer@4.0.0)': dependencies: @@ -8540,16 +8489,20 @@ snapshots: '@csstools/css-calc': 3.1.1(@csstools/css-parser-algorithms@4.0.0(@csstools/css-tokenizer@4.0.0))(@csstools/css-tokenizer@4.0.0) '@csstools/css-parser-algorithms': 4.0.0(@csstools/css-tokenizer@4.0.0) '@csstools/css-tokenizer': 4.0.0 + optional: true '@csstools/css-parser-algorithms@4.0.0(@csstools/css-tokenizer@4.0.0)': dependencies: '@csstools/css-tokenizer': 4.0.0 + optional: true - '@csstools/css-syntax-patches-for-csstree@1.1.1(css-tree@3.2.1)': + '@csstools/css-syntax-patches-for-csstree@1.1.2(css-tree@3.2.1)': optionalDependencies: css-tree: 3.2.1 + optional: true - '@csstools/css-tokenizer@4.0.0': {} + '@csstools/css-tokenizer@4.0.0': + optional: true '@e18e/eslint-plugin@0.2.0(eslint@10.1.0(jiti@1.21.7))(oxlint@1.56.0(oxlint-tsgolint@0.17.1))': dependencies: @@ -8558,10 +8511,10 @@ snapshots: eslint: 10.1.0(jiti@1.21.7) oxlint: 1.56.0(oxlint-tsgolint@0.17.1) - '@egoist/tailwindcss-icons@1.9.2(tailwindcss@3.4.19(tsx@4.21.0)(yaml@2.8.2))': + '@egoist/tailwindcss-icons@1.9.2(tailwindcss@3.4.19(tsx@4.21.0)(yaml@2.8.3))': dependencies: '@iconify/utils': 3.1.0 - tailwindcss: 
3.4.19(tsx@4.21.0)(yaml@2.8.2) + tailwindcss: 3.4.19(tsx@4.21.0)(yaml@2.8.3) '@emnapi/core@1.9.0': dependencies: @@ -8851,7 +8804,8 @@ snapshots: '@eslint/core': 1.1.1 levn: 0.4.1 - '@exodus/bytes@1.15.0': {} + '@exodus/bytes@1.15.0': + optional: true '@floating-ui/core@1.7.5': dependencies: @@ -9066,11 +9020,11 @@ snapshots: dependencies: minipass: 7.1.3 - '@joshwooding/vite-plugin-react-docgen-typescript@0.6.4(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))(typescript@5.9.3)': + '@joshwooding/vite-plugin-react-docgen-typescript@0.6.4(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(typescript@5.9.3)': dependencies: glob: 13.0.6 react-docgen-typescript: 2.4.0(typescript@5.9.3) - vite: '@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2)' + vite: '@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3)' optionalDependencies: typescript: 5.9.3 @@ -9325,15 +9279,6 @@ snapshots: react: 19.2.4 react-dom: 19.2.4(react@19.2.4) - '@mswjs/interceptors@0.41.3': - dependencies: - '@open-draft/deferred-promise': 2.2.0 - '@open-draft/logger': 0.3.0 - '@open-draft/until': 2.1.0 - is-node-process: 1.2.0 - outvariant: 1.4.3 - strict-event-emitter: 0.5.1 - '@napi-rs/wasm-runtime@1.1.1': dependencies: '@emnapi/core': 1.9.0 @@ -9400,57 +9345,6 @@ snapshots: '@nolyfill/side-channel@1.0.44': {} - '@octokit/auth-token@6.0.0': {} - - '@octokit/core@7.0.6': - dependencies: - '@octokit/auth-token': 6.0.0 - '@octokit/graphql': 9.0.3 - '@octokit/request': 10.0.8 - '@octokit/request-error': 7.1.0 - '@octokit/types': 16.0.0 - before-after-hook: 4.0.0 - universal-user-agent: 7.0.3 - - 
'@octokit/endpoint@11.0.3': - dependencies: - '@octokit/types': 16.0.0 - universal-user-agent: 7.0.3 - - '@octokit/graphql@9.0.3': - dependencies: - '@octokit/request': 10.0.8 - '@octokit/types': 16.0.0 - universal-user-agent: 7.0.3 - - '@octokit/openapi-types@27.0.0': {} - - '@octokit/request-error@7.1.0': - dependencies: - '@octokit/types': 16.0.0 - - '@octokit/request@10.0.8': - dependencies: - '@octokit/endpoint': 11.0.3 - '@octokit/request-error': 7.1.0 - '@octokit/types': 16.0.0 - fast-content-type-parse: 3.0.0 - json-with-bigint: 3.5.7 - universal-user-agent: 7.0.3 - - '@octokit/types@16.0.0': - dependencies: - '@octokit/openapi-types': 27.0.0 - - '@open-draft/deferred-promise@2.2.0': {} - - '@open-draft/logger@0.3.0': - dependencies: - is-node-process: 1.2.0 - outvariant: 1.4.3 - - '@open-draft/until@2.1.0': {} - '@orpc/client@1.13.9': dependencies: '@orpc/shared': 1.13.9 @@ -9817,7 +9711,7 @@ snapshots: detect-libc: 2.1.2 is-glob: 4.0.3 node-addon-api: 7.1.1 - picomatch: 4.0.3 + picomatch: 4.0.4 optionalDependencies: '@parcel/watcher-android-arm64': 2.5.6 '@parcel/watcher-darwin-arm64': 2.5.6 @@ -10153,7 +10047,7 @@ snapshots: dependencies: '@types/estree': 1.0.8 estree-walker: 2.0.2 - picomatch: 4.0.3 + picomatch: 4.0.4 optionalDependencies: rollup: 4.59.0 @@ -10311,10 +10205,10 @@ snapshots: '@standard-schema/spec@1.1.0': {} - '@storybook/addon-docs@10.3.1(@types/react@19.2.14)(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))(esbuild@0.27.2)(rollup@4.59.0)(storybook@10.3.1(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(webpack@5.105.4(esbuild@0.27.2)(uglify-js@3.19.3))': + 
'@storybook/addon-docs@10.3.1(@types/react@19.2.14)(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(esbuild@0.27.2)(rollup@4.59.0)(storybook@10.3.1(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(webpack@5.105.4(esbuild@0.27.2)(uglify-js@3.19.3))': dependencies: '@mdx-js/react': 3.1.1(@types/react@19.2.14)(react@19.2.4) - '@storybook/csf-plugin': 10.3.1(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))(esbuild@0.27.2)(rollup@4.59.0)(storybook@10.3.1(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(webpack@5.105.4(esbuild@0.27.2)(uglify-js@3.19.3)) + '@storybook/csf-plugin': 10.3.1(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(esbuild@0.27.2)(rollup@4.59.0)(storybook@10.3.1(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(webpack@5.105.4(esbuild@0.27.2)(uglify-js@3.19.3)) '@storybook/icons': 2.0.1(react-dom@19.2.4(react@19.2.4))(react@19.2.4) '@storybook/react-dom-shim': 10.3.1(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(storybook@10.3.1(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)) react: 19.2.4 @@ -10344,25 +10238,25 @@ snapshots: storybook: 10.3.1(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4) ts-dedent: 2.2.0 - '@storybook/builder-vite@10.3.1(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))(esbuild@0.27.2)(rollup@4.59.0)(storybook@10.3.1(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(webpack@5.105.4(esbuild@0.27.2)(uglify-js@3.19.3))': + 
'@storybook/builder-vite@10.3.1(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(esbuild@0.27.2)(rollup@4.59.0)(storybook@10.3.1(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(webpack@5.105.4(esbuild@0.27.2)(uglify-js@3.19.3))': dependencies: - '@storybook/csf-plugin': 10.3.1(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))(esbuild@0.27.2)(rollup@4.59.0)(storybook@10.3.1(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(webpack@5.105.4(esbuild@0.27.2)(uglify-js@3.19.3)) + '@storybook/csf-plugin': 10.3.1(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(esbuild@0.27.2)(rollup@4.59.0)(storybook@10.3.1(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(webpack@5.105.4(esbuild@0.27.2)(uglify-js@3.19.3)) storybook: 10.3.1(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4) ts-dedent: 2.2.0 - vite: '@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2)' + vite: '@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3)' transitivePeerDependencies: - esbuild - rollup - webpack - '@storybook/csf-plugin@10.3.1(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))(esbuild@0.27.2)(rollup@4.59.0)(storybook@10.3.1(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(webpack@5.105.4(esbuild@0.27.2)(uglify-js@3.19.3))': + 
'@storybook/csf-plugin@10.3.1(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(esbuild@0.27.2)(rollup@4.59.0)(storybook@10.3.1(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(webpack@5.105.4(esbuild@0.27.2)(uglify-js@3.19.3))': dependencies: storybook: 10.3.1(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4) unplugin: 2.3.11 optionalDependencies: esbuild: 0.27.2 rollup: 4.59.0 - vite: '@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2)' + vite: '@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3)' webpack: 5.105.4(esbuild@0.27.2)(uglify-js@3.19.3) '@storybook/global@5.0.0': {} @@ -10372,18 +10266,18 @@ snapshots: react: 19.2.4 react-dom: 19.2.4(react@19.2.4) - '@storybook/nextjs-vite@10.3.1(@babel/core@7.29.0)(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))(esbuild@0.27.2)(next@16.2.1(@babel/core@7.29.0)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(sass@1.98.0))(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(rollup@4.59.0)(storybook@10.3.1(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(typescript@5.9.3)(webpack@5.105.4(esbuild@0.27.2)(uglify-js@3.19.3))': + 
'@storybook/nextjs-vite@10.3.1(@babel/core@7.29.0)(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(esbuild@0.27.2)(next@16.2.1(@babel/core@7.29.0)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(sass@1.98.0))(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(rollup@4.59.0)(storybook@10.3.1(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(typescript@5.9.3)(webpack@5.105.4(esbuild@0.27.2)(uglify-js@3.19.3))': dependencies: - '@storybook/builder-vite': 10.3.1(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))(esbuild@0.27.2)(rollup@4.59.0)(storybook@10.3.1(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(webpack@5.105.4(esbuild@0.27.2)(uglify-js@3.19.3)) + '@storybook/builder-vite': 10.3.1(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(esbuild@0.27.2)(rollup@4.59.0)(storybook@10.3.1(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(webpack@5.105.4(esbuild@0.27.2)(uglify-js@3.19.3)) '@storybook/react': 10.3.1(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(storybook@10.3.1(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(typescript@5.9.3) - '@storybook/react-vite': 10.3.1(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))(esbuild@0.27.2)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(rollup@4.59.0)(storybook@10.3.1(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(typescript@5.9.3)(webpack@5.105.4(esbuild@0.27.2)(uglify-js@3.19.3)) + '@storybook/react-vite': 
10.3.1(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(esbuild@0.27.2)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(rollup@4.59.0)(storybook@10.3.1(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(typescript@5.9.3)(webpack@5.105.4(esbuild@0.27.2)(uglify-js@3.19.3)) next: 16.2.1(@babel/core@7.29.0)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(sass@1.98.0) react: 19.2.4 react-dom: 19.2.4(react@19.2.4) storybook: 10.3.1(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4) styled-jsx: 5.1.6(@babel/core@7.29.0)(react@19.2.4) - vite: '@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2)' - vite-plugin-storybook-nextjs: 3.2.3(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))(next@16.2.1(@babel/core@7.29.0)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(sass@1.98.0))(storybook@10.3.1(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(typescript@5.9.3) + vite: '@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3)' + vite-plugin-storybook-nextjs: 3.2.3(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(next@16.2.1(@babel/core@7.29.0)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(sass@1.98.0))(storybook@10.3.1(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(typescript@5.9.3) optionalDependencies: typescript: 5.9.3 transitivePeerDependencies: @@ -10400,11 +10294,11 @@ snapshots: react-dom: 19.2.4(react@19.2.4) storybook: 
10.3.1(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4) - '@storybook/react-vite@10.3.1(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))(esbuild@0.27.2)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(rollup@4.59.0)(storybook@10.3.1(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(typescript@5.9.3)(webpack@5.105.4(esbuild@0.27.2)(uglify-js@3.19.3))': + '@storybook/react-vite@10.3.1(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(esbuild@0.27.2)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(rollup@4.59.0)(storybook@10.3.1(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(typescript@5.9.3)(webpack@5.105.4(esbuild@0.27.2)(uglify-js@3.19.3))': dependencies: - '@joshwooding/vite-plugin-react-docgen-typescript': 0.6.4(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))(typescript@5.9.3) + '@joshwooding/vite-plugin-react-docgen-typescript': 0.6.4(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(typescript@5.9.3) '@rollup/pluginutils': 5.3.0(rollup@4.59.0) - '@storybook/builder-vite': 10.3.1(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))(esbuild@0.27.2)(rollup@4.59.0)(storybook@10.3.1(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(webpack@5.105.4(esbuild@0.27.2)(uglify-js@3.19.3)) + '@storybook/builder-vite': 
10.3.1(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(esbuild@0.27.2)(rollup@4.59.0)(storybook@10.3.1(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(webpack@5.105.4(esbuild@0.27.2)(uglify-js@3.19.3)) '@storybook/react': 10.3.1(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(storybook@10.3.1(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(typescript@5.9.3) empathic: 2.0.0 magic-string: 0.30.21 @@ -10414,7 +10308,7 @@ snapshots: resolve: 1.22.11 storybook: 10.3.1(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4) tsconfig-paths: 4.2.0 - vite: '@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2)' + vite: '@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3)' transitivePeerDependencies: - esbuild - rollup @@ -10453,7 +10347,7 @@ snapshots: eslint-visitor-keys: 4.2.1 espree: 10.4.0 estraverse: 5.3.0 - picomatch: 4.0.3 + picomatch: 4.0.4 '@svgdotjs/svg.js@3.2.5': {} @@ -10479,10 +10373,10 @@ snapshots: valibot: 1.3.0(typescript@5.9.3) zod: 4.3.6 - '@tailwindcss/typography@0.5.19(tailwindcss@3.4.19(tsx@4.21.0)(yaml@2.8.2))': + '@tailwindcss/typography@0.5.19(tailwindcss@3.4.19(tsx@4.21.0)(yaml@2.8.3))': dependencies: postcss-selector-parser: 6.0.10 - tailwindcss: 3.4.19(tsx@4.21.0)(yaml@2.8.2) + tailwindcss: 3.4.19(tsx@4.21.0)(yaml@2.8.3) '@tanstack/devtools-client@0.0.6': dependencies: @@ -10490,7 +10384,7 @@ snapshots: '@tanstack/devtools-event-bus@0.4.1': dependencies: - ws: 8.19.0 + ws: 8.20.0 transitivePeerDependencies: - bufferutil - utf-8-validate @@ -10948,6 +10842,12 @@ snapshots: '@types/unist@3.0.3': {} + '@types/whatwg-mimetype@3.0.2': {} + + '@types/ws@8.18.1': + 
dependencies: + '@types/node': 25.5.0 + '@types/yauzl@2.10.3': dependencies: '@types/node': 25.5.0 @@ -11131,12 +11031,12 @@ snapshots: '@resvg/resvg-wasm': 2.4.0 satori: 0.16.0 - '@vitejs/plugin-react@6.0.1(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))': + '@vitejs/plugin-react@6.0.1(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))': dependencies: '@rolldown/pluginutils': 1.0.0-rc.7 - vite: '@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2)' + vite: '@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3)' - '@vitejs/plugin-rsc@0.5.21(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))(react-dom@19.2.4(react@19.2.4))(react-server-dom-webpack@19.2.4(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(webpack@5.105.4(esbuild@0.27.2)(uglify-js@3.19.3)))(react@19.2.4)': + '@vitejs/plugin-rsc@0.5.21(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(react-dom@19.2.4(react@19.2.4))(react-server-dom-webpack@19.2.4(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(webpack@5.105.4(esbuild@0.27.2)(uglify-js@3.19.3)))(react@19.2.4)': dependencies: '@rolldown/pluginutils': 1.0.0-rc.5 es-module-lexer: 2.0.0 @@ -11148,12 +11048,12 @@ snapshots: srvx: 0.11.12 strip-literal: 3.1.0 turbo-stream: 3.2.0 - vite: '@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2)' - vitefu: 
1.1.2(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2)) + vite: '@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3)' + vitefu: 1.1.2(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3)) optionalDependencies: react-server-dom-webpack: 19.2.4(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(webpack@5.105.4(esbuild@0.27.2)(uglify-js@3.19.3)) - '@vitest/coverage-v8@4.1.0(@voidzero-dev/vite-plus-test@0.1.13(@types/node@25.5.0)(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))(esbuild@0.27.2)(jiti@1.21.7)(jsdom@29.0.1(canvas@3.2.2))(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))': + '@vitest/coverage-v8@4.1.0(@voidzero-dev/vite-plus-test@0.1.13(@types/node@25.5.0)(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(esbuild@0.27.2)(happy-dom@20.8.9)(jiti@1.21.7)(jsdom@29.0.1(canvas@3.2.2))(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))': dependencies: '@bcoe/v8-coverage': 1.0.2 '@vitest/utils': 4.1.0 @@ -11165,16 +11065,16 @@ snapshots: obug: 2.1.1 std-env: 4.0.0 tinyrainbow: 3.1.0 - vitest: '@voidzero-dev/vite-plus-test@0.1.13(@types/node@25.5.0)(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))(esbuild@0.27.2)(jiti@1.21.7)(jsdom@29.0.1(canvas@3.2.2))(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2)' + vitest: 
'@voidzero-dev/vite-plus-test@0.1.13(@types/node@25.5.0)(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(esbuild@0.27.2)(happy-dom@20.8.9)(jiti@1.21.7)(jsdom@29.0.1(canvas@3.2.2))(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3)' - '@vitest/eslint-plugin@1.6.12(@voidzero-dev/vite-plus-test@0.1.13(@types/node@25.5.0)(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))(esbuild@0.27.2)(jiti@1.21.7)(jsdom@29.0.1(canvas@3.2.2))(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3)': + '@vitest/eslint-plugin@1.6.12(@voidzero-dev/vite-plus-test@0.1.13(@types/node@25.5.0)(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(esbuild@0.27.2)(happy-dom@20.8.9)(jiti@1.21.7)(jsdom@29.0.1(canvas@3.2.2))(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3)': dependencies: '@typescript-eslint/scope-manager': 8.57.1 '@typescript-eslint/utils': 8.57.1(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3) eslint: 10.1.0(jiti@1.21.7) optionalDependencies: typescript: 5.9.3 - vitest: '@voidzero-dev/vite-plus-test@0.1.13(@types/node@25.5.0)(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))(esbuild@0.27.2)(jiti@1.21.7)(jsdom@29.0.1(canvas@3.2.2))(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2)' + vitest: 
'@voidzero-dev/vite-plus-test@0.1.13(@types/node@25.5.0)(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(esbuild@0.27.2)(happy-dom@20.8.9)(jiti@1.21.7)(jsdom@29.0.1(canvas@3.2.2))(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3)' transitivePeerDependencies: - supports-color @@ -11210,7 +11110,7 @@ snapshots: convert-source-map: 2.0.0 tinyrainbow: 3.1.0 - '@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2)': + '@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3)': dependencies: '@oxc-project/runtime': 0.120.0 '@oxc-project/types': 0.120.0 @@ -11225,7 +11125,7 @@ snapshots: terser: 5.46.1 tsx: 4.21.0 typescript: 5.9.3 - yaml: 2.8.2 + yaml: 2.8.3 '@voidzero-dev/vite-plus-darwin-arm64@0.1.13': optional: true @@ -11239,11 +11139,11 @@ snapshots: '@voidzero-dev/vite-plus-linux-x64-gnu@0.1.13': optional: true - '@voidzero-dev/vite-plus-test@0.1.13(@types/node@25.5.0)(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))(esbuild@0.27.2)(jiti@1.21.7)(jsdom@29.0.1(canvas@3.2.2))(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2)': + '@voidzero-dev/vite-plus-test@0.1.13(@types/node@25.5.0)(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(esbuild@0.27.2)(happy-dom@20.8.9)(jiti@1.21.7)(jsdom@29.0.1(canvas@3.2.2))(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3)': dependencies: '@standard-schema/spec': 1.1.0 '@types/chai': 5.2.3 - '@voidzero-dev/vite-plus-core': 
0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2) + '@voidzero-dev/vite-plus-core': 0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3) es-module-lexer: 1.7.0 obug: 2.1.1 pixelmatch: 7.1.0 @@ -11253,10 +11153,11 @@ snapshots: tinybench: 2.9.0 tinyexec: 1.0.4 tinyglobby: 0.2.15 - vite: '@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2)' + vite: '@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3)' ws: 8.19.0 optionalDependencies: '@types/node': 25.5.0 + happy-dom: 20.8.9 jsdom: 29.0.1(canvas@3.2.2) transitivePeerDependencies: - '@arethetypeswrong/core' @@ -11495,7 +11396,7 @@ snapshots: anymatch@3.1.3: dependencies: normalize-path: 3.0.0 - picomatch: 2.3.1 + picomatch: 2.3.2 are-docs-informative@0.0.2: {} @@ -11553,13 +11454,10 @@ snapshots: baseline-browser-mapping@2.10.8: {} - before-after-hook@4.0.0: {} - - bezier-easing@2.1.0: {} - bidi-js@1.0.3: dependencies: require-from-string: 2.0.2 + optional: true binary-extensions@2.3.0: {} @@ -11851,8 +11749,6 @@ snapshots: css-gradient-parser@0.0.16: {} - css-mediaquery@0.1.2: {} - css-select@5.2.2: dependencies: boolbase: 1.0.0 @@ -11881,6 +11777,7 @@ snapshots: dependencies: mdn-data: 2.27.1 source-map-js: 1.2.1 + optional: true css-what@6.2.2: {} @@ -12086,6 +11983,7 @@ snapshots: whatwg-url: 16.0.1 transitivePeerDependencies: - '@noble/hashes' + optional: true dayjs@1.11.20: {} @@ -12340,7 +12238,7 @@ snapshots: dependencies: eslint: 10.1.0(jiti@1.21.7) - eslint-plugin-better-tailwindcss@4.3.2(eslint@10.1.0(jiti@1.21.7))(oxlint@1.56.0(oxlint-tsgolint@0.17.1))(tailwindcss@3.4.19(tsx@4.21.0)(yaml@2.8.2))(typescript@5.9.3): + 
eslint-plugin-better-tailwindcss@4.3.2(eslint@10.1.0(jiti@1.21.7))(oxlint@1.56.0(oxlint-tsgolint@0.17.1))(tailwindcss@3.4.19(tsx@4.21.0)(yaml@2.8.3))(typescript@5.9.3): dependencies: '@eslint/css-tree': 3.6.9 '@valibot/to-json-schema': 1.6.0(valibot@1.3.0(typescript@5.9.3)) @@ -12348,7 +12246,7 @@ snapshots: jiti: 2.6.1 synckit: 0.11.12 tailwind-csstree: 0.1.5 - tailwindcss: 3.4.19(tsx@4.21.0)(yaml@2.8.2) + tailwindcss: 3.4.19(tsx@4.21.0)(yaml@2.8.3) tsconfig-paths-webpack-plugin: 4.2.0 valibot: 1.3.0(typescript@5.9.3) optionalDependencies: @@ -12458,6 +12356,14 @@ snapshots: transitivePeerDependencies: - typescript + eslint-plugin-no-barrel-files@1.2.2(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3): + dependencies: + '@typescript-eslint/utils': 8.57.1(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3) + transitivePeerDependencies: + - eslint + - supports-color + - typescript + eslint-plugin-no-only-tests@3.3.0: {} eslint-plugin-perfectionist@5.7.0(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3): @@ -12477,7 +12383,7 @@ snapshots: pathe: 2.0.3 pnpm-workspace-yaml: 1.6.0 tinyglobby: 0.2.15 - yaml: 2.8.2 + yaml: 2.8.3 yaml-eslint-parser: 2.0.0 eslint-plugin-react-dom@3.0.0(eslint@10.1.0(jiti@1.21.7))(typescript@5.9.3): @@ -12873,8 +12779,6 @@ snapshots: transitivePeerDependencies: - supports-color - fast-content-type-parse@3.0.0: {} - fast-deep-equal@3.1.3: {} fast-glob@3.3.1: @@ -12915,9 +12819,9 @@ snapshots: dependencies: walk-up-path: 4.0.0 - fdir@6.5.0(picomatch@4.0.3): + fdir@6.5.0(picomatch@4.0.4): optionalDependencies: - picomatch: 4.0.3 + picomatch: 4.0.4 fflate@0.4.8: {} @@ -13027,6 +12931,18 @@ snapshots: hachure-fill@0.5.2: {} + happy-dom@20.8.9: + dependencies: + '@types/node': 25.5.0 + '@types/whatwg-mimetype': 3.0.2 + '@types/ws': 8.18.1 + entities: 7.0.1 + whatwg-mimetype: 3.0.0 + ws: 8.20.0 + transitivePeerDependencies: + - bufferutil + - utf-8-validate + has-flag@4.0.0: {} hast-util-from-dom@5.0.1: @@ -13191,6 +13107,7 @@ snapshots: '@exodus/bytes': 1.15.0 
transitivePeerDependencies: - '@noble/hashes' + optional: true html-entities@2.6.0: {} @@ -13325,13 +13242,12 @@ snapshots: dependencies: is-docker: 3.0.0 - is-node-process@1.2.0: {} - is-number@7.0.0: {} is-plain-obj@4.1.0: {} - is-potential-custom-element-name@1.0.1: {} + is-potential-custom-element-name@1.0.1: + optional: true is-reference@3.0.3: dependencies: @@ -13393,17 +13309,12 @@ snapshots: jsdoc-type-pratt-parser@7.1.1: {} - jsdom-testing-mocks@1.16.0: - dependencies: - bezier-easing: 2.1.0 - css-mediaquery: 0.1.2 - jsdom@29.0.1(canvas@3.2.2): dependencies: - '@asamuzakjp/css-color': 5.0.1 - '@asamuzakjp/dom-selector': 7.0.3 + '@asamuzakjp/css-color': 5.1.1 + '@asamuzakjp/dom-selector': 7.0.4 '@bramus/specificity': 2.4.2 - '@csstools/css-syntax-patches-for-csstree': 1.1.1(css-tree@3.2.1) + '@csstools/css-syntax-patches-for-csstree': 1.1.2(css-tree@3.2.1) '@exodus/bytes': 1.15.0 css-tree: 3.2.1 data-urls: 7.0.0 @@ -13415,7 +13326,7 @@ snapshots: saxes: 6.0.0 symbol-tree: 3.2.4 tough-cookie: 6.0.1 - undici: 7.24.5 + undici: 7.24.6 w3c-xmlserializer: 5.0.0 webidl-conversions: 8.0.1 whatwg-mimetype: 5.0.0 @@ -13425,6 +13336,7 @@ snapshots: canvas: 3.2.2 transitivePeerDependencies: - '@noble/hashes' + optional: true jsesc@3.1.0: {} @@ -13438,10 +13350,6 @@ snapshots: json-stable-stringify-without-jsonify@1.0.1: {} - json-stringify-safe@5.0.1: {} - - json-with-bigint@3.5.7: {} - json5@2.2.3: {} jsonc-eslint-parser@3.1.0: @@ -13481,11 +13389,11 @@ snapshots: oxc-parser: 0.120.0 oxc-resolver: 11.19.1 picocolors: 1.1.1 - picomatch: 4.0.3 - smol-toml: 1.6.0 + picomatch: 4.0.4 + smol-toml: 1.6.1 strip-json-comments: 5.0.3 unbash: 2.2.0 - yaml: 2.8.2 + yaml: 2.8.3 zod: 4.3.6 kolorist@1.8.0: {} @@ -13591,10 +13499,10 @@ snapshots: dependencies: commander: 14.0.3 listr2: 9.0.5 - picomatch: 4.0.3 + picomatch: 4.0.4 string-argv: 0.3.2 tinyexec: 1.0.4 - yaml: 2.8.2 + yaml: 2.8.3 listr2@9.0.5: dependencies: @@ -13889,7 +13797,8 @@ snapshots: mdn-data@2.23.0: {} - 
mdn-data@2.27.1: {} + mdn-data@2.27.1: + optional: true memoize-one@5.2.1: {} @@ -14215,7 +14124,7 @@ snapshots: micromatch@4.0.8: dependencies: braces: 3.0.3 - picomatch: 2.3.1 + picomatch: 2.3.2 mime-db@1.52.0: {} @@ -14326,12 +14235,6 @@ snapshots: - '@babel/core' - babel-plugin-macros - nock@14.0.11: - dependencies: - '@mswjs/interceptors': 0.41.3 - json-stringify-safe: 5.0.1 - propagate: 2.0.1 - node-abi@3.89.0: dependencies: semver: 7.7.4 @@ -14401,8 +14304,6 @@ snapshots: type-check: 0.4.0 word-wrap: 1.2.5 - outvariant@1.4.3: {} - oxc-parser@0.120.0: dependencies: '@oxc-project/types': 0.120.0 @@ -14613,9 +14514,9 @@ snapshots: picocolors@1.1.1: {} - picomatch@2.3.1: {} + picomatch@2.3.2: {} - picomatch@4.0.3: {} + picomatch@4.0.4: {} pify@2.3.0: {} @@ -14645,7 +14546,7 @@ snapshots: pnpm-workspace-yaml@1.6.0: dependencies: - yaml: 2.8.2 + yaml: 2.8.3 points-on-curve@0.2.0: {} @@ -14677,14 +14578,14 @@ snapshots: dependencies: postcss: 8.5.8 - postcss-load-config@6.0.1(jiti@1.21.7)(postcss@8.5.8)(tsx@4.21.0)(yaml@2.8.2): + postcss-load-config@6.0.1(jiti@1.21.7)(postcss@8.5.8)(tsx@4.21.0)(yaml@2.8.3): dependencies: lilconfig: 3.1.3 optionalDependencies: jiti: 1.21.7 postcss: 8.5.8 tsx: 4.21.0 - yaml: 2.8.2 + yaml: 2.8.3 postcss-nested@6.2.0(postcss@8.5.8): dependencies: @@ -14752,8 +14653,6 @@ snapshots: object-assign: 4.1.1 react-is: 16.13.1 - propagate@2.0.1: {} - property-information@5.6.0: dependencies: xtend: 4.0.2 @@ -14991,7 +14890,7 @@ snapshots: readdirp@3.6.0: dependencies: - picomatch: 2.3.1 + picomatch: 2.3.2 readdirp@4.1.2: {} @@ -15267,6 +15166,7 @@ snapshots: saxes@6.0.0: dependencies: xmlchars: 2.2.0 + optional: true scheduler@0.27.0: {} @@ -15366,7 +15266,7 @@ snapshots: ansi-styles: 6.2.3 is-fullwidth-code-point: 5.1.0 - smol-toml@1.6.0: {} + smol-toml@1.6.1: {} solid-js@1.9.11: dependencies: @@ -15452,8 +15352,6 @@ snapshots: transitivePeerDependencies: - supports-color - strict-event-emitter@0.5.1: {} - string-argv@0.3.2: {} 
string-ts@2.3.1: {} @@ -15545,7 +15443,8 @@ snapshots: picocolors: 1.1.1 sax: 1.6.0 - symbol-tree@3.2.4: {} + symbol-tree@3.2.4: + optional: true synckit@0.11.12: dependencies: @@ -15561,7 +15460,7 @@ snapshots: tailwind-merge@3.5.0: {} - tailwindcss@3.4.19(tsx@4.21.0)(yaml@2.8.2): + tailwindcss@3.4.19(tsx@4.21.0)(yaml@2.8.3): dependencies: '@alloc/quick-lru': 5.2.0 arg: 5.0.2 @@ -15580,7 +15479,7 @@ snapshots: postcss: 8.5.8 postcss-import: 15.1.0(postcss@8.5.8) postcss-js: 4.1.0(postcss@8.5.8) - postcss-load-config: 6.0.1(jiti@1.21.7)(postcss@8.5.8)(tsx@4.21.0)(yaml@2.8.2) + postcss-load-config: 6.0.1(jiti@1.21.7)(postcss@8.5.8)(tsx@4.21.0)(yaml@2.8.3) postcss-nested: 6.2.0(postcss@8.5.8) postcss-selector-parser: 6.1.2 resolve: 1.22.11 @@ -15591,6 +15490,8 @@ snapshots: tapable@2.3.0: {} + tapable@2.3.2: {} + tar-fs@2.1.4: dependencies: chownr: 1.1.4 @@ -15630,7 +15531,7 @@ snapshots: tinyexec: 1.0.4 tinyglobby: 0.2.15 unconfig: 7.5.0 - yaml: 2.8.2 + yaml: 2.8.3 terser-webpack-plugin@5.4.0(esbuild@0.27.2)(uglify-js@3.19.3)(webpack@5.105.4(esbuild@0.27.2)(uglify-js@3.19.3)): dependencies: @@ -15670,8 +15571,8 @@ snapshots: tinyglobby@0.2.15: dependencies: - fdir: 6.5.0(picomatch@4.0.3) - picomatch: 4.0.3 + fdir: 6.5.0(picomatch@4.0.4) + picomatch: 4.0.4 tinypool@2.1.0: {} @@ -15707,10 +15608,12 @@ snapshots: tough-cookie@6.0.1: dependencies: tldts: 7.0.27 + optional: true tr46@6.0.0: dependencies: punycode: 2.3.1 + optional: true trim-lines@3.0.1: {} @@ -15724,7 +15627,7 @@ snapshots: ts-declaration-location@1.0.7(typescript@5.9.3): dependencies: - picomatch: 4.0.3 + picomatch: 4.0.4 typescript: 5.9.3 ts-dedent@2.2.0: {} @@ -15803,7 +15706,8 @@ snapshots: undici@7.24.0: {} - undici@7.24.5: {} + undici@7.24.6: + optional: true unicode-trie@2.0.0: dependencies: @@ -15857,8 +15761,6 @@ snapshots: unist-util-is: 6.0.1 unist-util-visit-parents: 6.0.2 - universal-user-agent@7.0.3: {} - universalify@2.0.1: {} unpic@4.2.2: {} @@ -15866,13 +15768,13 @@ snapshots: 
unplugin-utils@0.3.1: dependencies: pathe: 2.0.3 - picomatch: 4.0.3 + picomatch: 4.0.4 unplugin@2.3.11: dependencies: '@jridgewell/remapping': 2.3.5 acorn: 8.16.0 - picomatch: 4.0.3 + picomatch: 4.0.4 webpack-virtual-modules: 0.6.2 update-browserslist-db@1.2.3(browserslist@4.28.1): @@ -15955,36 +15857,36 @@ snapshots: '@types/unist': 3.0.3 vfile-message: 4.0.3 - vinext@https://pkg.pr.new/vinext@b6a2cac(1a91bf00ec5f7fb5f0ffb625316f9d01): + vinext@https://pkg.pr.new/vinext@b6a2cac(33c71b051bfc49f90bf5d8b6a8976975): dependencies: '@unpic/react': 1.0.2(next@16.2.1(@babel/core@7.29.0)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(sass@1.98.0))(react-dom@19.2.4(react@19.2.4))(react@19.2.4) '@vercel/og': 0.8.6 - '@vitejs/plugin-react': 6.0.1(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2)) + '@vitejs/plugin-react': 6.0.1(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3)) magic-string: 0.30.21 react: 19.2.4 react-dom: 19.2.4(react@19.2.4) rsc-html-stream: 0.0.7 - vite: '@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2)' + vite: '@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3)' vite-plugin-commonjs: 0.10.4 - vite-tsconfig-paths: 6.1.1(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))(typescript@5.9.3) + vite-tsconfig-paths: 6.1.1(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(typescript@5.9.3) optionalDependencies: '@mdx-js/rollup': 
3.1.1(rollup@4.59.0) - '@vitejs/plugin-rsc': 0.5.21(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))(react-dom@19.2.4(react@19.2.4))(react-server-dom-webpack@19.2.4(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(webpack@5.105.4(esbuild@0.27.2)(uglify-js@3.19.3)))(react@19.2.4) + '@vitejs/plugin-rsc': 0.5.21(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(react-dom@19.2.4(react@19.2.4))(react-server-dom-webpack@19.2.4(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(webpack@5.105.4(esbuild@0.27.2)(uglify-js@3.19.3)))(react@19.2.4) react-server-dom-webpack: 19.2.4(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(webpack@5.105.4(esbuild@0.27.2)(uglify-js@3.19.3)) transitivePeerDependencies: - next - supports-color - typescript - vite-dev-rpc@1.1.0(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2)): + vite-dev-rpc@1.1.0(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3)): dependencies: birpc: 2.9.0 - vite: '@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2)' - vite-hot-client: 2.1.0(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2)) + vite: '@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3)' + vite-hot-client: 
2.1.0(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3)) - vite-hot-client@2.1.0(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2)): + vite-hot-client@2.1.0(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3)): dependencies: - vite: '@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2)' + vite: '@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3)' vite-plugin-commonjs@0.10.4: dependencies: @@ -15999,7 +15901,7 @@ snapshots: fast-glob: 3.3.3 magic-string: 0.30.21 - vite-plugin-inspect@11.3.3(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2)): + vite-plugin-inspect@11.3.3(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3)): dependencies: ansis: 4.2.0 debug: 4.4.3 @@ -16009,12 +15911,12 @@ snapshots: perfect-debounce: 2.1.0 sirv: 3.0.2 unplugin-utils: 0.3.1 - vite: '@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2)' - vite-dev-rpc: 1.1.0(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2)) + vite: '@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3)' 
+ vite-dev-rpc: 1.1.0(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3)) transitivePeerDependencies: - supports-color - vite-plugin-storybook-nextjs@3.2.3(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))(next@16.2.1(@babel/core@7.29.0)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(sass@1.98.0))(storybook@10.3.1(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(typescript@5.9.3): + vite-plugin-storybook-nextjs@3.2.3(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(next@16.2.1(@babel/core@7.29.0)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(sass@1.98.0))(storybook@10.3.1(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4))(typescript@5.9.3): dependencies: '@next/env': 16.0.0 image-size: 2.0.2 @@ -16023,17 +15925,17 @@ snapshots: next: 16.2.1(@babel/core@7.29.0)(react-dom@19.2.4(react@19.2.4))(react@19.2.4)(sass@1.98.0) storybook: 10.3.1(@testing-library/dom@10.4.1)(react-dom@19.2.4(react@19.2.4))(react@19.2.4) ts-dedent: 2.2.0 - vite: '@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2)' - vite-tsconfig-paths: 5.1.4(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))(typescript@5.9.3) + vite: '@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3)' + vite-tsconfig-paths: 
5.1.4(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(typescript@5.9.3) transitivePeerDependencies: - supports-color - typescript - vite-plus@0.1.13(@types/node@25.5.0)(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))(esbuild@0.27.2)(jiti@1.21.7)(jsdom@29.0.1(canvas@3.2.2))(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2): + vite-plus@0.1.13(@types/node@25.5.0)(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(esbuild@0.27.2)(happy-dom@20.8.9)(jiti@1.21.7)(jsdom@29.0.1(canvas@3.2.2))(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3): dependencies: '@oxc-project/types': 0.120.0 - '@voidzero-dev/vite-plus-core': 0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2) - '@voidzero-dev/vite-plus-test': 0.1.13(@types/node@25.5.0)(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))(esbuild@0.27.2)(jiti@1.21.7)(jsdom@29.0.1(canvas@3.2.2))(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2) + '@voidzero-dev/vite-plus-core': 0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3) + '@voidzero-dev/vite-plus-test': 0.1.13(@types/node@25.5.0)(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(esbuild@0.27.2)(happy-dom@20.8.9)(jiti@1.21.7)(jsdom@29.0.1(canvas@3.2.2))(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3) cac: 7.0.0 cross-spawn: 
7.0.6 oxfmt: 0.41.0 @@ -16075,36 +15977,36 @@ snapshots: - vite - yaml - vite-tsconfig-paths@5.1.4(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))(typescript@5.9.3): + vite-tsconfig-paths@5.1.4(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(typescript@5.9.3): dependencies: debug: 4.4.3 globrex: 0.1.2 tsconfck: 3.1.6(typescript@5.9.3) optionalDependencies: - vite: '@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2)' + vite: '@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3)' transitivePeerDependencies: - supports-color - typescript - vite-tsconfig-paths@6.1.1(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))(typescript@5.9.3): + vite-tsconfig-paths@6.1.1(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(typescript@5.9.3): dependencies: debug: 4.4.3 globrex: 0.1.2 tsconfck: 3.1.6(typescript@5.9.3) - vite: '@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2)' + vite: '@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3)' transitivePeerDependencies: - supports-color - typescript - 
vitefu@1.1.2(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2)): + vitefu@1.1.2(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3)): optionalDependencies: - vite: '@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2)' + vite: '@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3)' - vitest-canvas-mock@1.1.3(@voidzero-dev/vite-plus-test@0.1.13(@types/node@25.5.0)(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))(esbuild@0.27.2)(jiti@1.21.7)(jsdom@29.0.1(canvas@3.2.2))(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2)): + vitest-canvas-mock@1.1.3(@voidzero-dev/vite-plus-test@0.1.13(@types/node@25.5.0)(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(esbuild@0.27.2)(happy-dom@20.8.9)(jiti@1.21.7)(jsdom@29.0.1(canvas@3.2.2))(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3)): dependencies: cssfontparser: 1.2.1 moo-color: 1.0.3 - vitest: '@voidzero-dev/vite-plus-test@0.1.13(@types/node@25.5.0)(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))(esbuild@0.27.2)(jiti@1.21.7)(jsdom@29.0.1(canvas@3.2.2))(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2)' + vitest: 
'@voidzero-dev/vite-plus-test@0.1.13(@types/node@25.5.0)(@voidzero-dev/vite-plus-core@0.1.13(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3))(esbuild@0.27.2)(happy-dom@20.8.9)(jiti@1.21.7)(jsdom@29.0.1(canvas@3.2.2))(sass@1.98.0)(terser@5.46.1)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.3)' void-elements@3.1.0: {} @@ -16140,6 +16042,7 @@ snapshots: w3c-xmlserializer@5.0.0: dependencies: xml-name-validator: 5.0.0 + optional: true walk-up-path@4.0.0: {} @@ -16152,7 +16055,8 @@ snapshots: web-vitals@5.1.0: {} - webidl-conversions@8.0.1: {} + webidl-conversions@8.0.1: + optional: true webpack-sources@3.3.4: {} @@ -16181,7 +16085,7 @@ snapshots: mime-types: 2.1.35 neo-async: 2.6.2 schema-utils: 4.3.3 - tapable: 2.3.0 + tapable: 2.3.2 terser-webpack-plugin: 5.4.0(esbuild@0.27.2)(uglify-js@3.19.3)(webpack@5.105.4(esbuild@0.27.2)(uglify-js@3.19.3)) watchpack: 2.5.1 webpack-sources: 3.3.4 @@ -16194,9 +16098,12 @@ snapshots: dependencies: iconv-lite: 0.6.3 + whatwg-mimetype@3.0.0: {} + whatwg-mimetype@4.0.0: {} - whatwg-mimetype@5.0.0: {} + whatwg-mimetype@5.0.0: + optional: true whatwg-url@16.0.1: dependencies: @@ -16205,6 +16112,7 @@ snapshots: webidl-conversions: 8.0.1 transitivePeerDependencies: - '@noble/hashes' + optional: true which@2.0.2: dependencies: @@ -16222,15 +16130,19 @@ snapshots: ws@8.19.0: {} + ws@8.20.0: {} + wsl-utils@0.1.0: dependencies: is-wsl: 3.1.1 xml-name-validator@4.0.0: {} - xml-name-validator@5.0.0: {} + xml-name-validator@5.0.0: + optional: true - xmlchars@2.2.0: {} + xmlchars@2.2.0: + optional: true xtend@4.0.2: {} @@ -16241,9 +16153,9 @@ snapshots: yaml-eslint-parser@2.0.0: dependencies: eslint-visitor-keys: 5.0.1 - yaml: 2.8.2 + yaml: 2.8.3 - yaml@2.8.2: {} + yaml@2.8.3: {} yauzl@3.2.1: dependencies: diff --git a/web/proxy.ts b/web/proxy.ts index 02513d91b9..af9b290025 100644 --- a/web/proxy.ts +++ b/web/proxy.ts @@ -1,9 +1,11 @@ -import type { NextRequest } from 
'@/next/server' +// eslint-disable-next-line no-restricted-imports +import type { NextRequest } from 'next/server' import { Buffer } from 'node:buffer' +// eslint-disable-next-line no-restricted-imports +import { NextResponse } from 'next/server' import { env } from '@/env' -import { NextResponse } from '@/next/server' -const NECESSARY_DOMAIN = '*.sentry.io http://localhost:* http://127.0.0.1:* https://analytics.google.com googletagmanager.com *.googletagmanager.com https://www.google-analytics.com https://api.github.com https://api2.amplitude.com *.amplitude.com' +const NECESSARY_DOMAIN = '*.sentry.io http://localhost:* http://127.0.0.1:* https://analytics.google.com googletagmanager.com *.googletagmanager.com https://www.google-analytics.com https://ungh.cc https://api2.amplitude.com *.amplitude.com' const wrapResponseWithXFrameOptions = (response: NextResponse, pathname: string) => { // prevent clickjacking: https://owasp.org/www-community/attacks/Clickjacking diff --git a/web/types/model-provider.ts b/web/types/model-provider.ts index b98cc62441..7be0ccc58f 100644 --- a/web/types/model-provider.ts +++ b/web/types/model-provider.ts @@ -2,12 +2,13 @@ * Model provider quota types - shared type definitions for API responses * These represent the provider identifiers that support paid/trial quotas */ -export enum ModelProviderQuotaGetPaid { - ANTHROPIC = 'langgenius/anthropic/anthropic', - OPENAI = 'langgenius/openai/openai', - // AZURE_OPENAI = 'langgenius/azure_openai/azure_openai', - GEMINI = 'langgenius/gemini/google', - X = 'langgenius/x/x', - DEEPSEEK = 'langgenius/deepseek/deepseek', - TONGYI = 'langgenius/tongyi/tongyi', -} +export const ModelProviderQuotaGetPaid = { + ANTHROPIC: 'langgenius/anthropic/anthropic', + OPENAI: 'langgenius/openai/openai', + // AZURE_OPENAI: 'langgenius/azure_openai/azure_openai', + GEMINI: 'langgenius/gemini/google', + X: 'langgenius/x/x', + DEEPSEEK: 'langgenius/deepseek/deepseek', + TONGYI: 'langgenius/tongyi/tongyi', +} as 
const +export type ModelProviderQuotaGetPaid = typeof ModelProviderQuotaGetPaid[keyof typeof ModelProviderQuotaGetPaid] diff --git a/web/types/workflow.ts b/web/types/workflow.ts index f8a53c8d7e..5c39246ee0 100644 --- a/web/types/workflow.ts +++ b/web/types/workflow.ts @@ -455,12 +455,13 @@ export type PanelProps = { export type NodeRunResult = NodeTracing // Var Inspect -export enum VarInInspectType { - conversation = 'conversation', - environment = 'env', - node = 'node', - system = 'sys', -} +export const VarInInspectType = { + conversation: 'conversation', + environment: 'env', + node: 'node', + system: 'sys', +} as const +export type VarInInspectType = typeof VarInInspectType[keyof typeof VarInInspectType] export type FullContent = { size_bytes: number diff --git a/web/vite.config.ts b/web/vite.config.ts index 617cae9ab5..28746f81ca 100644 --- a/web/vite.config.ts +++ b/web/vite.config.ts @@ -75,7 +75,8 @@ export default defineConfig(({ mode }) => { // Vitest config test: { - environment: 'jsdom', + pool: 'threads', + environment: 'happy-dom', globals: true, setupFiles: ['./vitest.setup.ts'], coverage: { diff --git a/web/vitest.setup.ts b/web/vitest.setup.ts index e63ea2b54e..ac26ac5d25 100644 --- a/web/vitest.setup.ts +++ b/web/vitest.setup.ts @@ -1,14 +1,8 @@ import { act, cleanup } from '@testing-library/react' -import { mockAnimationsApi, mockResizeObserver } from 'jsdom-testing-mocks' import * as React from 'react' import '@testing-library/jest-dom/vitest' import 'vitest-canvas-mock' -mockResizeObserver() - -// Mock Web Animations API for Headless UI -mockAnimationsApi() - // Suppress act() warnings from @headlessui/react internal Transition component // These warnings are caused by Headless UI's internal async state updates, not our code const originalConsoleError = console.error @@ -77,24 +71,10 @@ if (typeof globalThis.IntersectionObserver === 'undefined') { } } -// Mock Element.scrollIntoView for tests (not available in happy-dom/jsdom) -if (typeof 
Element !== 'undefined' && !Element.prototype.scrollIntoView) - Element.prototype.scrollIntoView = function () { /* noop */ } - -// Mock DOMRect.fromRect for tests (not available in jsdom) -if (typeof DOMRect !== 'undefined' && typeof (DOMRect as typeof DOMRect & { fromRect?: unknown }).fromRect !== 'function') { - (DOMRect as typeof DOMRect & { fromRect: (rect?: DOMRectInit) => DOMRect }).fromRect = (rect = {}) => new DOMRect( - rect.x ?? 0, - rect.y ?? 0, - rect.width ?? 0, - rect.height ?? 0, - ) -} - afterEach(async () => { // Wrap cleanup in act() to flush pending React scheduler work // This prevents "window is not defined" errors from React 19's scheduler - // which uses setImmediate/MessageChannel that can fire after jsdom cleanup + // which uses setImmediate/MessageChannel that can fire after DOM cleanup await act(async () => { cleanup() }) @@ -131,19 +111,97 @@ vi.mock('@floating-ui/react', async () => { } }) -// mock window.matchMedia -Object.defineProperty(window, 'matchMedia', { - writable: true, - value: vi.fn().mockImplementation(query => ({ - matches: false, - media: query, - onchange: null, - addListener: vi.fn(), // deprecated - removeListener: vi.fn(), // deprecated - addEventListener: vi.fn(), - removeEventListener: vi.fn(), - dispatchEvent: vi.fn(), - })), +vi.mock('@monaco-editor/react', () => { + const createEditorMock = () => { + const focusListeners: Array<() => void> = [] + const blurListeners: Array<() => void> = [] + + return { + getContentHeight: vi.fn(() => 56), + onDidFocusEditorText: vi.fn((listener: () => void) => { + focusListeners.push(listener) + return { dispose: vi.fn() } + }), + onDidBlurEditorText: vi.fn((listener: () => void) => { + blurListeners.push(listener) + return { dispose: vi.fn() } + }), + layout: vi.fn(), + getAction: vi.fn(() => ({ run: vi.fn() })), + getModel: vi.fn(() => ({ + getLineContent: vi.fn(() => ''), + })), + getPosition: vi.fn(() => ({ lineNumber: 1, column: 1 })), + deltaDecorations: vi.fn(() => []), + 
focus: vi.fn(() => { + focusListeners.forEach(listener => listener()) + }), + setPosition: vi.fn(), + revealLine: vi.fn(), + trigger: vi.fn(), + __blur: () => { + blurListeners.forEach(listener => listener()) + }, + } + } + + const monacoMock = { + editor: { + setTheme: vi.fn(), + defineTheme: vi.fn(), + }, + Range: class { + startLineNumber: number + startColumn: number + endLineNumber: number + endColumn: number + constructor(startLineNumber: number, startColumn: number, endLineNumber: number, endColumn: number) { + this.startLineNumber = startLineNumber + this.startColumn = startColumn + this.endLineNumber = endLineNumber + this.endColumn = endColumn + } + }, + } + + const MonacoEditor = ({ + value = '', + onChange, + onMount, + options, + }: { + value?: string + onChange?: (value: string | undefined) => void + onMount?: (editor: ReturnType, monaco: typeof monacoMock) => void + options?: { readOnly?: boolean } + }) => { + const editorRef = React.useRef | null>(null) + if (!editorRef.current) + editorRef.current = createEditorMock() + + React.useEffect(() => { + onMount?.(editorRef.current!, monacoMock) + }, [onMount]) + + return React.createElement('textarea', { + 'data-testid': 'monaco-editor', + 'readOnly': options?.readOnly, + value, + 'onChange': (event: React.ChangeEvent) => onChange?.(event.target.value), + 'onFocus': () => editorRef.current?.focus(), + 'onBlur': () => editorRef.current?.__blur(), + }) + } + + return { + __esModule: true, + default: MonacoEditor, + Editor: MonacoEditor, + loader: { + config: vi.fn(), + init: vi.fn().mockResolvedValue(monacoMock), + }, + } }) // Mock localStorage for testing